diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..30f0dedd8d26e2282d92a6a8caee7a4529915664 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,6 @@ +# Documentation files +docs/* @saadrahim @LisaDelaney +*.md @saadrahim @LisaDelaney +*.rst @saadrahim @LisaDelaney +# Header directory +library/include/* @saadrahim @LisaDelaney diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d6700ae05b51371a3d1ac50ba0d701e7384698a9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,14 @@ +repos: +- repo: local + hooks: + - id: clang-format + name: clang-format + entry: clang-format-12 -i --style=file + language: system + types_or: [c++, inc] + - id: copyright-year-checker + name: copyright-year-checker + entry: script/check_copyright_year.sh + verbose: false + language: script + types: [c++] diff --git a/CHANGELOG.md b/CHANGELOG.md index 01883500465583097aa68642e89a89a06b6750b2..31f129b58159076c3d3027e3c36ccc885e4b1baf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ Full documentation for Composable Kernel is not yet available. - Improve proformance of normalization kernel ### Added +- Added new cmake flag "DL_KERNELS" must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances. +- Added new cmake flag "DTYPES" which could be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instance of select data types. +- Added new cmake flag "INSTANCES_ONLY" which will only build CK library and instances without the tests, examples, or profiler. +- Added new feature: if GPU_TARGETS is not set on cmake command line, CK will be built for all targets supported by compiler. +- Added support on MI300A/MI300X. - Added support on NAVI3x. - Added user tutorial (#563). - Added more instances for irregular GEMM sizes (#560). @@ -20,6 +25,8 @@ Full documentation for Composable Kernel is not yet available. - Added multi-embeddings support (#542). - Added Navi3x blockwise GEMM and real GEMM support (#541). - Added Navi grouped ConvBwdWeight support (#505). +- Added MaxPool, AvgPool forward (#815). +- Added MaxPool backward (#750). ### Changed - Changed ... diff --git a/CMakeLists.txt b/CMakeLists.txt index c9fb6b4552c81ef29e77dcde807f552b42916572..4ca54d847d9cc9feb4136db8c591f7b7edbae2c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,61 @@ cmake_minimum_required(VERSION 3.14) +set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel) +project(composable_kernel VERSION ${version}) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") +if (DTYPES) + add_definitions(-DDTYPES) + if (DTYPES MATCHES "int8") + add_definitions(-DCK_ENABLE_INT8) + set(CK_ENABLE_INT8 "ON") + endif() + if (DTYPES MATCHES "fp8") + add_definitions(-DCK_ENABLE_FP8) + set(CK_ENABLE_FP8 "ON") + endif() + if (DTYPES MATCHES "fp16") + add_definitions(-DCK_ENABLE_FP16) + set(CK_ENABLE_FP16 "ON") + endif() + if (DTYPES MATCHES "fp32") + add_definitions(-DCK_ENABLE_FP32) + set(CK_ENABLE_FP32 "ON") + endif() + if (DTYPES MATCHES "fp64") + add_definitions(-DCK_ENABLE_FP64) + set(CK_ENABLE_FP64 "ON") + endif() + if (DTYPES MATCHES "bf16") + add_definitions(-DCK_ENABLE_BF16) + set(CK_ENABLE_BF16 "ON") + endif() + message("DTYPES macro set to ${DTYPES}") +else() + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + set(CK_ENABLE_ALL_DTYPES "ON") +endif() + +if(DL_KERNELS) + add_definitions(-DDL_KERNELS) + set(CK_ENABLE_DL_KERNELS "ON") +endif() + +if(INSTANCES_ONLY) + add_definitions(-DINSTANCES_ONLY) + set(CK_ENABLE_INSTANCES_ONLY "ON") +endif() + +# CK config file to record supported datatypes, etc. +configure_file("${PROJECT_SOURCE_DIR}/include/ck/config.h.in" "${PROJECT_BINARY_DIR}/include/ck/config.h") + +# CK version file to record release version as well as git commit hash +find_package(Git REQUIRED) +execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE) +configure_file("${PROJECT_SOURCE_DIR}/include/ck/version.h.in" "${PROJECT_BINARY_DIR}/include/ck/version.h") + enable_testing() set(ROCM_SYMLINK_LIBS OFF) @@ -16,11 +67,57 @@ include(ROCMSetupVersion) include(ROCMInstallSymlinks) include(ROCMCreatePackage) include(CheckCXXCompilerFlag) - -rocm_setup_version(VERSION 0.2.0) +include(ROCMCheckTargetIds) include(TargetFlags) + +rocm_setup_version(VERSION ${version}) + list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) +message("GPU_TARGETS= ${GPU_TARGETS}") + +message("checking which targets are supported") +#This is the list of targets to be used in case GPU_TARGETS is not set on command line +#These targets will be filtered and only supported ones will be used +#Setting GPU_TARGETS on command line will override this list +if(NOT PROFILER_ONLY) + rocm_check_target_ids(DEFAULT_GPU_TARGETS + TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") +else() + add_definitions(-DPROFILER_ONLY) + if(GPU_TARGETS) + message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx9, gfx10, or gfx11") + endif() + if(GPU_ARCH MATCHES "gfx9") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942") + elseif(GPU_ARCH MATCHES "gfx10") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") + elseif(GPU_ARCH MATCHES "gfx11") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") + else() + message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx9, gfx10, or gfx11") + endif() +endif() + +message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") + +set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " ") + +if(GPU_TARGETS) + message("Building CK for the following targets: ${GPU_TARGETS}") +else() + message("Building CK for the following targets: ${AMDGPU_TARGETS}") +endif() +find_package(hip) +# No assumption that HIP kernels are launched with uniform block size for backward compatibility +# SWDEV-413293 and https://reviews.llvm.org/D155213 +math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}") +message("hip_version_flat=${hip_VERSION_FLAT}") +if(${hip_VERSION_FLAT} GREATER 500723302) + message("Adding the fno-offload-uniform-block compiler flag") + add_compile_options(-fno-offload-uniform-block) +endif() + option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) @@ -238,13 +335,14 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) +# set CK project include directories include_directories(BEFORE + ${PROJECT_BINARY_DIR}/include ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/library/include ${HIP_INCLUDE_DIRS} ) - SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") if(BUILD_DEV) add_compile_options(-Werror) @@ -258,36 +356,80 @@ file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp" file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) set(CK_DEVICE_INSTANCES) FOREACH(subdir_path ${dir_list}) - IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}") - list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance) - ENDIF() +set(target_dir) +IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}") + set(cmake_instance) + file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance) + set(add_inst 0) + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8") + #message("fp8 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") + #message("fp16 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32") + #message("fp32 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64") + #message("fp64 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16") + #message("bf16 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8") + #message("int8 instance found!") + set(add_inst 1) + endif() + if(NOT "${cmake_instance}" MATCHES "DTYPES") + #message("instance should be built for all types!") + set(add_inst 1) + endif() + if(add_inst EQUAL 1 OR NOT DEFINED DTYPES) + list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance) + endif() +ENDIF() ENDFOREACH() + add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) +add_subdirectory(library) -rocm_package_setup_component(tests +if(NOT DEFINED INSTANCES_ONLY) + if(NOT DEFINED PROFILER_ONLY) + rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name -) + ) -rocm_package_setup_component(examples + rocm_package_setup_component(examples LIBRARY_NAME composablekernel PACKAGE_NAME examples -) + ) + add_subdirectory(example) + add_subdirectory(test) -rocm_package_setup_component(profiler + rocm_package_setup_component(profiler LIBRARY_NAME composablekernel PACKAGE_NAME ckProfiler -) - -add_subdirectory(library) -add_subdirectory(example) -add_subdirectory(test) -add_subdirectory(profiler) + ) + add_subdirectory(profiler) + else() + #When building PROFILER_ONLY, label the package with GPU_ARCH + rocm_package_setup_component(profiler + LIBRARY_NAME composablekernel + PACKAGE_NAME ckProfiler_${GPU_ARCH} + ) + add_subdirectory(profiler) + endif() +endif() #Create an interface target for the include only files and call it "composablekernels" include(CMakePackageConfigHelpers) -set(version 1.0.0) write_basic_package_version_file( "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" VERSION "${version}" @@ -295,9 +437,9 @@ write_basic_package_version_file( ) configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in - "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" - INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel - NO_CHECK_REQUIRED_COMPONENTS_MACRO + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel + NO_CHECK_REQUIRED_COMPONENTS_MACRO ) rocm_install(FILES @@ -306,6 +448,13 @@ rocm_install(FILES DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) +# Install CK version and configuration files +rocm_install(FILES + ${PROJECT_BINARY_DIR}/include/ck/version.h + ${PROJECT_BINARY_DIR}/include/ck/config.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/ +) + set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8ccfe99c3cc73b643f8b92cb654005e54c0774bd..cdce5a46309f59b27d8c658e785f70bf743527db 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -4,11 +4,13 @@ This is the list of developers and contributors to Composable Kernel library ## Developers -[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2022 +[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2023 -[Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2022 +[Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2023 -[Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), 2022 +[Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), [Astha Rai](https://github.com/arai713), [Shi YanXing](https://github.com/Yanxing-Shi), 2022-2023 + +[Hari Sadasivan](https://github.com/hsadasiv), [Bartlomiej Kocot](https://github.com/bartekxk), [Bartlomiej Wroblewski](https://github.com/bwroblew), 2023 Hanwen Chang, 2019-2021, diff --git a/Dockerfile b/Dockerfile index 8e6ddb1eba3dd934222fa4bd571729db6762dc69..e479268f48050957fabdaca7cfc5d162610f4fc1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ FROM ubuntu:20.04 - +ARG DEBIAN_FRONTEND=noninteractive ARG ROCMVERSION=5.6 ARG compiler_version="" ARG compiler_commit="" @@ -9,23 +9,38 @@ RUN set -xe ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ RUN useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins # Add rocm repository +RUN chmod 1777 /tmp RUN apt-get update -RUN apt-get install -y wget gnupg curl -RUN --mount=type=ssh if [ "$ROCMVERSION" != "5.6"]; then \ - wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"; \ - else sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amd-nonfree-radeon_20.04-1_all.deb" && \ - apt update && apt-get install -y ./amd-nonfree-radeon_20.04-1_all.deb && \ - amdgpu-repo --amdgpu-build=1567752 --rocm-build=compute-rocm-dkms-no-npi-hipclang/11914 && \ - DEBIAN_FRONTEND=noninteractive amdgpu-install -y --usecase=rocm ; \ +RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl + +ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn +RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg + +RUN wget https://repo.radeon.com/amdgpu-install/5.6/ubuntu/focal/amdgpu-install_5.6.50600-1_all.deb --no-check-certificate +RUN apt-get update && \ +DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ + ./amdgpu-install_5.6.50600-1_all.deb + +RUN if [ "$ROCMVERSION" != "5.7" ]; then \ + wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ + sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ + sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ + elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "" ] || [ "$compiler_version" = "amd-stg-open" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \ + apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \ + amdgpu-repo --amdgpu-build=1609671 --rocm-build=compute-rocm-npi-mi300/1354; \ + elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "rc1" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \ + apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 5.7 rel-19 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=1637781; \ fi -RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - + RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" -RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg +RUN amdgpu-install -y --usecase=rocm --no-dkms # Install dependencies RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ - apt-utils \ build-essential \ ccache \ cmake \ @@ -40,24 +55,24 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- pkg-config \ python \ python3 \ - python-dev \ python3-dev \ python3-pip \ sshpass \ software-properties-common \ - rocm-dev \ - rocm-device-libs \ - rocm-cmake \ vim \ nano \ zlib1g-dev \ openssh-server \ - clang-format-10 \ + clang-format-12 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* #Install latest version of cmake +RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip +RUN gunzip /usr/local/bin/ninja.gz +RUN chmod a+x /usr/local/bin/ninja +RUN git clone https://github.com/nico/ninjatracing.git RUN apt purge --auto-remove -y cmake RUN apt update RUN apt install -y software-properties-common lsb-release @@ -103,7 +118,7 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ +RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ @@ -111,7 +126,7 @@ RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler else echo "using the release compiler"; \ fi -RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ +RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ diff --git a/Jenkinsfile b/Jenkinsfile index fbff349fc3aaad004d81483fef6952d975660fdd..668d9a613bf3c0120076875b3882941530b66ae2 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -11,6 +11,20 @@ def show_node_info() { """ } +def nthreads() { + def nproc = sh(returnStdout: true, script: 'nproc') + echo "Number of cores: ${nproc}" + def n = nproc.toInteger() + if (n > 32){ + n /= 2 + } + if (n > 64){ + n = 64 + } + echo "Number of threads used for building: ${n}" + return n +} + def runShell(String command){ def responseCode = sh returnStatus: true, script: "${command} > tmp.txt" def output = readFile(file: "tmp.txt") @@ -19,7 +33,7 @@ def runShell(String command){ def getDockerImageName(){ def img - if (params.ROCMVERSION != "5.6"){ + if (params.ROCMVERSION != "5.7"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" } @@ -219,7 +233,8 @@ def cmake_build(Map conf=[:]){ """ def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") // reduce parallelism when compiling, clang uses too much memory - def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j\$(( \$(nproc) / 2 )) ${config_targets}") + def nt = nthreads() + def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}") def execute_cmd = conf.get("execute_cmd", "") def cmd = conf.get("cmd", """ @@ -461,7 +476,7 @@ def Build_CK(Map conf=[:]){ else{ echo "GPU is OK" } - if ( runShell('grep -n "gfx1030" clinfo.log') ){ + if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){ navi_node = 1 } } @@ -482,7 +497,7 @@ def Build_CK(Map conf=[:]){ else{ echo "GPU is OK" } - if ( runShell('grep -n "gfx1030" clinfo.log') ){ + if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){ navi_node = 1 } } @@ -493,8 +508,8 @@ def Build_CK(Map conf=[:]){ { cmake_build(conf) dir("build"){ - //run tests and examples - sh 'make -j\$(( \$(nproc) / 2 )) check' + //run tests and examples + sh 'make -j check' if (navi_node == 0 ){ //we only need the ckProfiler to run the performance tests, so we pack and stash it //do not stash profiler on Navi nodes @@ -597,8 +612,8 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true - 0 21 * * * % ROCMVERSION=5.5;COMPILER_VERSION=release;COMPILER_COMMIT= +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1 + 0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT= 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" pipeline { @@ -674,7 +689,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'" } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) @@ -695,8 +710,8 @@ pipeline { } agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -717,7 +732,7 @@ pipeline { Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') } } - stage("Build CK and run Tests on Navi") + stage("Build CK and run Tests on Navi21") { when { beforeAgent true @@ -725,7 +740,7 @@ pipeline { } agent{ label rocmnode("navi21") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ } @@ -733,6 +748,22 @@ pipeline { Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') } } + stage("Build CK and run Tests on Navi32") + { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() } + } + agent{ label rocmnode("navi32") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + } + } } } diff --git a/LICENSE b/LICENSE index 2fe9a8455efaeda2eab474b2aa038ec2d9e76841..e03fddaf78080705d26ec277629cfb8010c077bf 100644 --- a/LICENSE +++ b/LICENSE @@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) SPDX-License-Identifier: MIT -Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index a45f61a37df2b82805a9a0643abc97821803ad5e..c2b493db11596d489c6bc5e4fcf9549080998c01 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,8 @@ CK is released under the MIT license. [License File](/LICENSE) ```bash DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . ``` +Pre-built dockers are available from this public repo: +https://hub.docker.com/r/rocm/composable_kernel/tags ## Launch docker @@ -76,12 +78,26 @@ mkdir build && cd build cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-O3" \ -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx908;gfx90a" \ .. ``` +If GPU_TARGETS is not set on the cmake command line, CK will be built for all targets supported by the +current compiler. + + +Additional cmake flags can be used to significantly speed-up the build: + +INSTANCES_ONLY (by default is OFF) must be set to ON in order to build only the instances and library +while skipping all tests, examples, and profiler. This is useful for libraries that use CK as a dependency. + +DTYPES (by default not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instances +of select data types only. Currently, building of int8 instances is taking a lot of time (the compiler fix is in the works). + +DL_KERNELS (by default is OFF) must be set to ON in order to build the gemm_dl and batched_gemm_multi_d_dl +instances. Those instances are only needed for the NAVI2x platforms. + ### Build examples and tests ```bash @@ -109,6 +125,24 @@ make install Instructions for using CK as a pre-built kernel library are under [client_example](/client_example) +## Contributing + +When you contribute to Composable Kernel, make sure to run `clang-format` on all the changed files. We highly recommend using git hooks that are managed by the `pre-commit` framework. To install hooks, run: + +```bash +sudo script/install_precommit.sh +``` + +This way, `pre-commit` will add the appropriate hooks to your local repository and automatically run `clang-format` (and possibly additional checks) before any commit is created. + +If you need to uninstall hooks from the repository, you can do so by running the following command: + +```bash +script/uninstall_precommit.sh +``` + +If for any reason, you need to temporarily disable precommit hooks, you can add the `--no-verify` option to the `git commit` command. + ## Caveat ### Kernel Timing and Verification diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp index ba7118ba3929e3e3bcdf02e40044748860bfeebe..c37f208db1cee9bfff1fff469fd79d059fb179f0 100644 --- a/client_example/01_gemm/gemm.cpp +++ b/client_example/01_gemm/gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index b7b724ccc484d6c9ea83e791aa5b8fcde420eef7..ba2952022233b3ffd21fc245b8fd199a5ec5b096 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -11,3 +11,17 @@ target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_ope add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu client_gemm_fastgelu) + +add_custom_target(client_gemm_fastgelu_generic_examples) + +add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) +target_link_libraries(client_gemm_add_add_fastgelu_generic PRIVATE composable_kernel::device_operations) + +add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) +target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_operations) + +add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) +target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_operations) + +add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic + client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp index 08f297f58a8d522aec8c991a8d38484ccf8e420c..756889562e84c66efb5f972621bcb61edda3af82 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ed942f0adf64df90ea1046d1de191ce2247d7c4 --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using D1DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 9) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideD1 = std::stoi(argv[7]); + StrideE = std::stoi(argv[8]); + } + else + { + printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, D0Layout{})); + SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) * + f_matrix_space_size(M, N, StrideD1, D1Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // get generic instance + auto& op_ptr = op_ptrs[0]; + + std::cout << "Run the generic instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + // run the generic instance + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + else + { + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + } + + std::cout << "Done" << std::endl; + + return 0; +} diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp index 658c1e9e8fcbeab1ec7be4115d5a7d62c5e77283..8d2a8c234aae63a3566478b5aa9588389a247d4e 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -76,7 +76,7 @@ int main(int argc, char* argv[]) StrideA = std::stoi(argv[4]); StrideB = std::stoi(argv[5]); StrideD0 = std::stoi(argv[6]); - StrideE = std::stoi(argv[8]); + StrideE = std::stoi(argv[7]); } else { diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..644b428fc9f51b28a0f81ed7d79ac741ca6fbcbe --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 8) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideE = std::stoi(argv[7]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD0, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, D0Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // get generic instance + auto& op_ptr = op_ptrs[0]; + + std::cout << "Run the generic instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + // run the generic instance + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + else + { + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + } + + std::cout << "Done" << std::endl; + + return 0; +} diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp index ea269545a5cc576f7cf98ea8496acf713b6aaea6..c02df018fd35c6d37a2c6f9b9fded6390c9afb19 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -72,7 +72,7 @@ int main(int argc, char* argv[]) StrideA = std::stoi(argv[4]); StrideB = std::stoi(argv[5]); - StrideE = std::stoi(argv[8]); + StrideE = std::stoi(argv[6]); } else { diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..482e93b421f7700cdf37ee014cf407ca3c63555b --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = FastGelu; + +using ADataType = F16; +using BDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideE = std::stoi(argv[6]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::FastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // get generic instance + auto& op_ptr = op_ptrs[0]; + + std::cout << "Run the generic instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + // run the generic instance + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + else + { + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + } + + std::cout << "Done" << std::endl; + + return 0; +} diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp index caa6573788d201b9fda605e5388423d5966d41ec..58c91f903bc7e2f3f296b06c746b1629513bc7f2 100644 --- a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -172,18 +172,19 @@ int main() BLayout, CLayout>(); - const auto normalize_ptrs = - ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances< - CDataType, - ReduceDataType, - ReduceDataType, - GammaDataType, - BetaDataType, - LayerNormOutDataType>(); - std::cout << "found " << gemm_reduce_ptrs.size() << " gemm_reduceMean_reduceSquareMean instances" << std::endl; + using NormalizeDeviceOp = ck::tensor_operation::device::DeviceElementwise< + ck::Tuple, + ck::Tuple, + ck::tensor_operation::element_wise::Normalize, + 2>; + + const auto normalize_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + NormalizeDeviceOp>::GetInstances(); + std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl; auto f_matrix_space_size = diff --git a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp index d4f0c2048ba81c3a1f98a22c5eb57654b0498806..3d5fb6004844af269f7786ae05de8c20cc720633 100644 --- a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_bilinear_fp32.cpp b/client_example/04_contraction/contraction_bilinear_fp32.cpp index 91dead41a4cac19db857b99a233839e9e6647c57..89f834b9824e134f8f0aeed8aa54f78a5c8824a3 100644 --- a/client_example/04_contraction/contraction_bilinear_fp32.cpp +++ b/client_example/04_contraction/contraction_bilinear_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_bilinear_fp64.cpp b/client_example/04_contraction/contraction_bilinear_fp64.cpp index 9238e4cd80075a2caf47d8757a24f7cf82c4b8bd..1aa3ba7de597a0b97a295b0f7ee7ad21b1e9cd80 100644 --- a/client_example/04_contraction/contraction_bilinear_fp64.cpp +++ b/client_example/04_contraction/contraction_bilinear_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp index 62be3377a2fb18bf388b857ecb758c0b7987871c..f8ea2258c2ba262e9db42f1b4dc92ff16cdc6286 100644 --- a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp +++ b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_scale_fp32.cpp b/client_example/04_contraction/contraction_scale_fp32.cpp index 4e08ee19cdb098b2dfb70a662d59c87008400123..ba7b0633c33aabeeb06547edfb08f506e637e599 100644 --- a/client_example/04_contraction/contraction_scale_fp32.cpp +++ b/client_example/04_contraction/contraction_scale_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_scale_fp64.cpp b/client_example/04_contraction/contraction_scale_fp64.cpp index 3c36aa21eb6c34df75f765cce894d4f137d4f080..24e52eb5aa423339ff96ad0914dc479d715fe7b7 100644 --- a/client_example/04_contraction/contraction_scale_fp64.cpp +++ b/client_example/04_contraction/contraction_scale_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/05_layernorm/layernorm2d.cpp b/client_example/05_layernorm/layernorm2d.cpp index 856a4cc21935f094bfd7040ea095fb04d78b7eb4..3ee7cead7b8a9f5069922c2c61921b640cf65ed7 100644 --- a/client_example/05_layernorm/layernorm2d.cpp +++ b/client_example/05_layernorm/layernorm2d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -100,6 +100,10 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + @@ -153,6 +157,10 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/06_softmax/softmax4d.cpp b/client_example/06_softmax/softmax4d.cpp index e939ce8dfedb10166b15388cae1d659d33e2154a..2ccad27a88757b576a050d0558acc4108857cbd1 100644 --- a/client_example/06_softmax/softmax4d.cpp +++ b/client_example/06_softmax/softmax4d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -53,12 +53,35 @@ int main(int argc, char* argv[]) SimpleDeviceMem in(sizeof(InDataType) * num_elements); SimpleDeviceMem out(sizeof(OutDataType) * num_elements); - using DeviceOp = ck::tensor_operation::device:: - DeviceSoftmax; + using DeviceOp = ck::tensor_operation::device::DeviceSoftmax; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); + auto& generic_op_ptr = op_ptrs[0]; + + auto generic_argument_ptr = generic_op_ptr->MakeArgumentPointer(in_lengths, + in_strides, + reduce_dims, + alpha, + beta, + in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + PassThrough{}, + PassThrough{}); + + if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) + { + throw std::runtime_error( + "The generic kernel instance should be able to support any input shapes"); + }; + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::string best_op_name; @@ -74,11 +97,6 @@ int main(int argc, char* argv[]) { auto& op_ptr = op_ptrs[i]; - if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim) - { - continue; - } - auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 9fbdb83b1cf10082958730e5115f3bf54e12d415..70be0101c6d92512ab28a72797d2b8c46fb55281 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index 0a798be270fe4148f904f35df3550be552e294ff..57a210fa1f5f0d3f0eca014e08ccd119d29e4fd6 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/08_fused_attention/fused_attention.cpp b/client_example/08_fused_attention/fused_attention.cpp index fe927da1248786a4b943f610ce38b75f0d88defd..df6bc11a70d32df221612466a8af0fbcd9cafb1c 100644 --- a/client_example/08_fused_attention/fused_attention.cpp +++ b/client_example/08_fused_attention/fused_attention.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/08_fused_attention/fused_attention_bias.cpp b/client_example/08_fused_attention/fused_attention_bias.cpp index 3113b7856025af74c610981db130a8d965c36a24..6c9f3bc8f6f5a3c06f339f1246b5b3985e11d2d8 100644 --- a/client_example/08_fused_attention/fused_attention_bias.cpp +++ b/client_example/08_fused_attention/fused_attention_bias.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index 2b7d6fc806ad48f650c3e1f72ef3190cb9f342f4..ac11aad45de84b8a2e042d8a1a4bb059b889f3f2 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations) @@ -18,3 +19,4 @@ target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable add_executable(client_gemm_quantization gemm_quantization.cpp) target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations) +endif() diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp index 43a4779f5ff0db4558b384ced53f3bfa78ca9c4b..cd504e942e943b8f174442554eca379cd908ba6e 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp index 2ff91fe966ed39181b3611e2a5cf1840fc273121..f4aa3666b1c15eccd37bdef549f8402bcd1b252b 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp index 6ea5dd2231b905af6572ec7f7f5cc07ff3edb27e..ebdbbf52c0ca257dd8f672a48b32f80fa5bb8816 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp index 33407c9a1cef1f56591c11a135e7be816cf18d6c..9d60baee06fddb27451f1f51642bad738739c4bb 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp index 6f51124707f15fedef12797d96085f2659612ce4..dd81d9ee6b6dc0ee453cac61d7e795c79ba3c9fa 100644 --- a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp index 6a11f9fc26a7e6fd0f2fc7022058c4a72f6a6610..9c088a21d38e7f62c8eee35bf4c557a0bb2209bd 100644 --- a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/gemm_quantization.cpp b/client_example/09_quantization/gemm_quantization.cpp index 242504b44ff942dcefdbf6cf7ffdc7d6b5b2bbdb..b14e68fa082f8f9d05ab5f00471a2fa82d1d113b 100644 --- a/client_example/09_quantization/gemm_quantization.cpp +++ b/client_example/09_quantization/gemm_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp index 55c789804230ccccf66d68be9244c5c4111451e6..1b2e8abc201c2aed2cd2eebccb68405a25033a43 100644 --- a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/11_grouped_conv_bwd_weight/common.hpp b/client_example/11_grouped_conv_bwd_weight/common.hpp index a906263333c8a15947e12fe8702e431d4acb7999..4292cded2071526381db36c31a5e660f7087cbce 100644 --- a/client_example/11_grouped_conv_bwd_weight/common.hpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -32,63 +32,49 @@ struct SimpleDeviceMem }; template -std::size_t GetFlops(ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - const std::array& output_spatial_lengths, - const std::array& filter_spatial_lengths) +std::size_t GetFlops(const std::array& output_lengths, + const std::array& filter_lengths) { + constexpr ck::index_t spatial_offset = 3; + const auto C = filter_lengths[2]; // 2 * G * N * K * C * * - return static_cast(2) * G * N * K * C * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), + return static_cast(2) * C * + std::accumulate(std::begin(output_lengths), + std::end(output_lengths), static_cast(1), std::multiplies<>()) * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), + std::accumulate(std::begin(filter_lengths) + spatial_offset, + std::end(filter_lengths), static_cast(1), std::multiplies<>()); } template -std::size_t GetInputByte(ck::index_t G, - ck::index_t N, - ck::index_t C, - const std::array& input_spatial_lengths) +std::size_t GetInputByte(const std::array& input_lengths) { // sizeof(InDataType) * (G * N * C * ) + - return sizeof(InDataType) * (G * N * C * - std::accumulate(std::begin(input_spatial_lengths), - std::end(input_spatial_lengths), + return sizeof(InDataType) * (std::accumulate(std::begin(input_lengths), + std::end(input_lengths), static_cast(1), std::multiplies<>())); } template -std::size_t GetWeightByte(ck::index_t G, - ck::index_t K, - ck::index_t C, - const std::array& filter_spatial_lengths) +std::size_t GetWeightByte(const std::array& filter_lengths) { // sizeof(WeiDataType) * (G * K * C * ) + - return sizeof(WeiDataType) * (G * K * C * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), + return sizeof(WeiDataType) * (std::accumulate(std::begin(filter_lengths), + std::end(filter_lengths), static_cast(1), std::multiplies<>())); } template -std::size_t GetOutputByte(ck::index_t G, - ck::index_t N, - ck::index_t K, - const std::array& output_spatial_lengths) +std::size_t GetOutputByte(const std::array& output_lengths) { // sizeof(OutDataType) * (G * N * K * ); - return sizeof(OutDataType) * (G * N * K * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), + return sizeof(OutDataType) * (std::accumulate(std::begin(output_lengths), + std::end(output_lengths), static_cast(1), std::multiplies())); } @@ -101,13 +87,12 @@ template bool run_grouped_conv_bwd_weight( - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, + const std::array& input_lengths, + const std::array& input_strides, + const std::array& filter_lengths, + const std::array& weights_strides, + const std::array& output_lengths, + const std::array& output_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -115,9 +100,9 @@ bool run_grouped_conv_bwd_weight( { ck::index_t split_k = 2; - SimpleDeviceMem in(GetInputByte(G, N, C, input_spatial_lengths)); - SimpleDeviceMem wei(GetWeightByte(G, K, C, filter_spatial_lengths)); - SimpleDeviceMem out(GetOutputByte(G, N, K, output_spatial_lengths)); + SimpleDeviceMem in(GetInputByte(input_lengths)); + SimpleDeviceMem wei(GetWeightByte(filter_lengths)); + SimpleDeviceMem out(GetOutputByte(output_lengths)); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + // profile device operation instances std::cout << "Run all instances and do timing" << std::endl; @@ -150,13 +139,12 @@ bool run_grouped_conv_bwd_weight( auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), out.GetDeviceBuffer(), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -172,12 +160,10 @@ bool run_grouped_conv_bwd_weight( { float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - std::size_t flop = - GetFlops(G, N, K, C, output_spatial_lengths, filter_spatial_lengths); - std::size_t num_bytes = - GetInputByte(G, N, C, input_spatial_lengths) + - GetWeightByte(G, K, C, filter_spatial_lengths) + - GetOutputByte(G, N, K, output_spatial_lengths); + std::size_t flop = GetFlops(output_lengths, filter_lengths); + std::size_t num_bytes = GetInputByte(input_lengths) + + GetWeightByte(filter_lengths) + + GetOutputByte(output_lengths); float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; @@ -217,13 +203,12 @@ bool run_grouped_conv_bwd_weight( auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), out.GetDeviceBuffer(), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp index 788d50ddefb704d9c08c081ae24f40480882a04f..e6d427faf4d75033a931ec2bfeeb72f4cec90e31 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp @@ -22,6 +22,16 @@ static constexpr ck::index_t C = 192; static constexpr ck::index_t X = 3; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wo = 28; +static constexpr std::array input_lengths{G, N, C, Wi}; +static constexpr std::array filter_lengths{G, K, C, X}; +static constexpr std::array output_lengths{G, N, K, Wo}; +static constexpr std::array input_strides{N * Wi * C, Wi* C, 1, C}; +static constexpr std::array weights_strides{K * X * C, X* C, 1, C}; +static constexpr std::array output_strides{N * Wo * K, Wo* K, 1, K}; +static constexpr std::array conv_filter_strides{1}; +static constexpr std::array conv_filter_dilations{1}; +static constexpr std::array input_left_pads{1}; +static constexpr std::array input_right_pads{1}; int main() { @@ -31,7 +41,16 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>(G, N, K, C, {Wi}, {X}, {Wo}, {1}, {1}, {1}, {1}) + OutLayout>(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp index 1903bd95b67b0834c73047cde0cdae6fe4e7fe82..4201ea61b4ee3c62f0bb8034b6e0510242000942 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp @@ -25,6 +25,19 @@ static constexpr ck::index_t Hi = 28; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Wo = 28; +static constexpr std::array input_lengths{G, N, C, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Y, X}; +static constexpr std::array output_lengths{G, N, K, Ho, Wo}; +static constexpr std::array input_strides{ + N * Hi * Wi * C, Hi* Wi* C, 1, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Y * X * C, Y* X* C, 1, X* C, C}; +static constexpr std::array output_strides{ + N * Ho * Wo * K, Ho* Wo* K, 1, Wo* K, K}; +static constexpr std::array conv_filter_strides{1, 1}; +static constexpr std::array conv_filter_dilations{1, 1}; +static constexpr std::array input_left_pads{1, 1}; +static constexpr std::array input_right_pads{1, 1}; int main() { @@ -34,8 +47,16 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>( - G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1}) + OutLayout>(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp index 2f2b5d4e2113c090c0e171a4204f3f9cfde2d276..3ae46bcd5566cd3958e1b4e666119ec55e3528e9 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp @@ -28,6 +28,19 @@ static constexpr ck::index_t Wi = 3; static constexpr ck::index_t Do = 28; static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Wo = 3; +static constexpr std::array input_lengths{G, N, C, Di, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Z, Y, X}; +static constexpr std::array output_lengths{G, N, K, Do, Ho, Wo}; +static constexpr std::array input_strides{ + N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C}; +static constexpr std::array output_strides{ + N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K}; +static constexpr std::array conv_filter_strides{1, 1, 1}; +static constexpr std::array conv_filter_dilations{1, 1, 1}; +static constexpr std::array input_left_pads{1, 1, 1}; +static constexpr std::array input_right_pads{1, 1, 1}; int main() { @@ -37,17 +50,16 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>(G, - N, - K, - C, - {Di, Hi, Wi}, - {Z, Y, X}, - {Do, Ho, Wo}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}) + OutLayout>(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp index 796311d2318e7d86198fd8ea631680876028c985..2eb869f3923ad0d743d5d0061842683e2fd95f87 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp @@ -28,6 +28,19 @@ static constexpr ck::index_t Wi = 3; static constexpr ck::index_t Do = 28; static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Wo = 3; +static constexpr std::array input_lengths{G, N, C, Di, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Z, Y, X}; +static constexpr std::array output_lengths{G, N, K, Do, Ho, Wo}; +static constexpr std::array input_strides{ + N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C}; +static constexpr std::array output_strides{ + N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K}; +static constexpr std::array conv_filter_strides{1, 1, 1}; +static constexpr std::array conv_filter_dilations{1, 1, 1}; +static constexpr std::array input_left_pads{1, 1, 1}; +static constexpr std::array input_right_pads{1, 1, 1}; int main() { @@ -37,17 +50,16 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>(G, - N, - K, - C, - {Di, Hi, Wi}, - {Z, Y, X}, - {Do, Ho, Wo}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}) + OutLayout>(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index de68f46d398958917e49bf14178f66414590ed86..bc4a6fe0bfa9e118bbd6ba32ecc7dd68f3b8b2c3 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp index 8ef21986a4d9a5f1eb25e21a1073c4cc341da88d..1ed36e0f50f05483a938a9a7b1c433df003b3b39 100644 --- a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -191,6 +191,12 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp index 322667a46bacae8d0c681939c3890ef9ff476b0e..f9af011c8480a5caa510a66467edc33e886fe0e2 100644 --- a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -187,6 +187,12 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp index 3117d162db71a0a4d02a15decac20e1f0f50d56e..5e6627ce14d113224c7b0acb4fea69cb36c1f369 100644 --- a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp index 9cfeee1cfe106e69f83dc0184f3956d6751a2947..d45782d8e0ff37027a204c6820447286581f1138 100644 --- a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp +++ b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp index 28524a9eee9b87db4223484497f8e2181ed366ea..c74d7c6bd8cf9cec54168792a8718768827d7b33 100644 --- a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp +++ b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index 2275158bcb26d36d871c9e7086b45d4539584785..b45b72f0de0199daa88e9e42be320e2669398dfe 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index a6bb5aa65be098716b504ae9d0543ed9ae129dc0..449c9466e829baa8658f17fe8e9ae86c5a297238 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -141,14 +141,10 @@ bool run_grouped_conv_fwd(std::array( - {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp index 43c98f1e9b8ae8b89893d5e430d4f4ab89678803..7e8c98b6037418001571707a82aa14987059fb4f 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp @@ -11,7 +11,7 @@ using WeiDataType = float; using OutDataType = float; using InLayout = ck::tensor_layout::convolution::NDHWGC; -using WeiLayout = ck::tensor_layout::convolution::KZYXGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; using OutLayout = ck::tensor_layout::convolution::NDHWGK; static constexpr ck::index_t NumDimSpatial = 3; @@ -38,7 +38,7 @@ int main() InLayout, WeiLayout, OutLayout>( - {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp index 223ed29be9ac66e8805d8d5b3e696bbac4f5c922..7ba3224fc3244ab3bd472b98864efa42ae132fde 100644 --- a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp +++ b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/18_groupnorm/groupnorm_swish.cpp b/client_example/18_groupnorm/groupnorm_swish.cpp index 84f62ceac437cb1cf950388e17a2af4487dda18f..df0a9ceec61715526ddea288472eb7ef9144ae3d 100644 --- a/client_example/18_groupnorm/groupnorm_swish.cpp +++ b/client_example/18_groupnorm/groupnorm_swish.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -72,6 +72,30 @@ int main(int argc, char* argv[]) std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + const auto& generic_op_ptr = op_ptrs[0]; + + auto generic_argument_ptr = + generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), + nullptr, + nullptr, + Swish{}); + + if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) + { + throw std::runtime_error( + "The generic kernel instance should be able to support any input shapes"); + }; + std::string best_op_name; bool found = false; int best_op_id = -1; @@ -105,6 +129,10 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); std::size_t num_byte = @@ -160,6 +188,10 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/19_pool/CMakeLists.txt b/client_example/19_pool/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4e2e6d4dc2d18c551e6252af631404b396943d5 --- /dev/null +++ b/client_example/19_pool/CMakeLists.txt @@ -0,0 +1,11 @@ +add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) +target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_operations) + +add_executable(client_max_pool2d_bwd max_pool2d_bwd.cpp) +target_link_libraries(client_max_pool2d_bwd PRIVATE composable_kernel::device_operations) + +add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp) +target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_operations) + +add_executable(client_avg_pool3d_bwd avg_pool3d_bwd.cpp) +target_link_libraries(client_avg_pool3d_bwd PRIVATE composable_kernel::device_operations) diff --git a/client_example/19_pool/avg_pool3d_bwd.cpp b/client_example/19_pool/avg_pool3d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..686d1da3ad61a5cb9ca4c8bdec034591cd354c55 --- /dev/null +++ b/client_example/19_pool/avg_pool3d_bwd.cpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/avg_pool3d_bwd.hpp" + +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; + +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{}, mMemSize_(mem_size) + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + void SetZero() const { (void)hipMemset(p_mem_, 0, mMemSize_); } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; + std::size_t mMemSize_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_d = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Zs = (Z - 1) * window_dilation_d + 1; + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCDHW + std::vector in_length = {N, C, Di, Hi, Wi}; + std::vector out_length = {N, C, Do, Ho, Wo}; + std::vector window_spatial_lengths = {Z, Y, X}; + std::vector window_strides = {window_stride_d, window_stride_h, window_stride_w}; + std::vector window_dilations{ + window_dilation_d, window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w}; + + std::size_t in_tensor_size = N * C * Di * Hi * Wi; + std::size_t out_tensor_size = N * C * Do * Ho * Wo; + + // tensor layout = NDHWC + std::vector in_tensor_stride = {Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}; + std::vector out_tensor_stride = {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}; + + SimpleDeviceMem dout_device_buf(sizeof(DOutDataType) * out_tensor_size); + SimpleDeviceMem din_device_buf(sizeof(DInDataType) * in_tensor_size); + + using DeviceOp = ck::tensor_operation::device:: + DeviceAvgPoolBwd<3, DOutDataType, DInDataType, DOutLayout, DInLayout>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_length, + in_length, + out_tensor_stride, + in_tensor_stride, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + din_device_buf.SetZero(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + in_tensor_size * sizeof(DInDataType) + out_tensor_size * sizeof(DOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_length, + in_length, + out_tensor_stride, + in_tensor_stride, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + din_device_buf.SetZero(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool_fwd/avg_pool3d_fwd.cpp b/client_example/19_pool/avg_pool3d_fwd.cpp similarity index 77% rename from client_example/19_pool_fwd/avg_pool3d_fwd.cpp rename to client_example/19_pool/avg_pool3d_fwd.cpp index 47bd7738f69189b95541a0722c7c1b43914fd98b..db8e0569d717f08e9b64e4504d40dd23aab13129 100644 --- a/client_example/19_pool_fwd/avg_pool3d_fwd.cpp +++ b/client_example/19_pool/avg_pool3d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,9 @@ using InDataType = ck::half_t; using OutDataType = ck::half_t; using IndexDataType = int32_t; +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + constexpr ck::index_t InOutRank = 5; constexpr ck::index_t WindowRank = 3; #if 0 @@ -44,33 +47,41 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { - ck::index_t N = 2; - ck::index_t C = 32; - ck::index_t Z = 2; - ck::index_t Y = 2; - ck::index_t X = 2; - ck::index_t Di = 30; - ck::index_t Hi = 30; - ck::index_t Wi = 30; - ck::index_t window_stride_d = 2; - ck::index_t window_stride_h = 2; - ck::index_t window_stride_w = 2; - ck::index_t in_left_pad_d = 1; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_d = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; - - ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1; - ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; - ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_d = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Zs = (Z - 1) * window_dilation_d + 1; + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; // Pool API only support the order of NCDHW std::vector in_length = {N, C, Di, Hi, Wi}; std::vector out_length = {N, C, Do, Ho, Wo}; std::vector window_spatial_lengths = {Z, Y, X}; - std::vector window_strides = {window_stride_d, window_stride_h, window_stride_w}; + std::vector window_strides = {window_stride_d, window_stride_h, window_stride_w}; + std::vector window_dilations{ + window_dilation_d, window_dilation_h, window_dilation_w}; std::vector input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w}; std::vector input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w}; @@ -90,6 +101,8 @@ int main(int argc, char* argv[]) InDataType, OutDataType, IndexDataType, + InLayout, + OutLayout, ReduceOpId, OutputIndex>; @@ -122,6 +135,7 @@ int main(int argc, char* argv[]) out_tensor_stride, out_tensor_stride, window_strides, + window_dilations, input_left_pads, input_right_pads, {2, 3, 4}); @@ -181,6 +195,7 @@ int main(int argc, char* argv[]) out_tensor_stride, out_tensor_stride, window_strides, + window_dilations, input_left_pads, input_right_pads, {2, 3, 4}); diff --git a/client_example/19_pool/max_pool2d_bwd.cpp b/client_example/19_pool/max_pool2d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53ece7425f99bd6eadc6de1a6f041a12fa0951d0 --- /dev/null +++ b/client_example/19_pool/max_pool2d_bwd.cpp @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" +#include "ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; +using IndexDataType = int32_t; + +// We use pool3d to implement pool2d in this example +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + +constexpr ck::index_t InOutRank = 5; +constexpr ck::index_t WindowRank = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +void TransformPool2dparamToPool3d(std::vector& input_lengths, + std::vector& window_lengths, + std::vector& output_lengths, + std::vector& input_stride, + std::vector& output_stride, + std::vector& indices_stride, + std::vector& window_strides, + std::vector& window_dilations, + std::vector& input_left_pads, + std::vector& input_right_pads, + std::vector& pooling_dims) +{ + // NCHW to NCDHW + input_lengths.insert(input_lengths.begin() + 2, 1); + output_lengths.insert(output_lengths.begin() + 2, 1); + input_stride.insert(input_stride.begin() + 2, 0); + output_stride.insert(output_stride.begin() + 2, 0); + indices_stride.insert(indices_stride.begin() + 2, 0); + + // YX to ZYX + window_lengths.insert(window_lengths.begin(), 1); + window_strides.insert(window_strides.begin(), 0); + window_dilations.insert(window_dilations.begin(), 0); + input_left_pads.insert(input_left_pads.begin(), 0); + input_right_pads.insert(input_right_pads.begin(), 0); + + pooling_dims = {2, 3, 4}; +} + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCHW + std::vector in_length = {N, C, Hi, Wi}; + std::vector out_length = {N, C, Ho, Wo}; + std::vector window_spatial_lengths = {Y, X}; + std::vector window_strides = {window_stride_h, window_stride_w}; + std::vector window_dilations = {window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_h, in_right_pad_w}; + std::vector pooling_dims = {2, 3}; + + std::size_t in_tensor_size = N * C * Hi * Wi; + std::size_t out_tensor_size = N * C * Ho * Wo; + + // tensor layout = NHWC + std::vector in_tensor_stride = {C * Hi * Wi, 1, Wi * C, C}; + std::vector out_tensor_stride = {C * Ho * Wo, 1, Wo * C, C}; + + TransformPool2dparamToPool3d(in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); + SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); + SimpleDeviceMem indices_device_buf(sizeof(IndexDataType) * out_tensor_size); + SimpleDeviceMem dout_device_buf(sizeof(DOutDataType) * out_tensor_size); + SimpleDeviceMem din_device_buf(sizeof(DInDataType) * in_tensor_size); + + // Generate index data from max pool forward + { + using MaxPoolFwdDeviceOp = + ck::tensor_operation::device::DevicePoolFwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + MaxPoolFwdDeviceOp>::GetInstances(); + + auto& op_ptr = op_ptrs[0]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + } + + // Run MaxPool bwd + using MaxPoolBwdDeviceOp = + ck::tensor_operation::device::DeviceMaxPoolBwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + MaxPoolBwdDeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_tensor_size, + in_tensor_size, + window_spatial_lengths, + window_strides, + window_dilations); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = in_tensor_size * sizeof(DInDataType) + + out_tensor_size * sizeof(IndexDataType) + + out_tensor_size * sizeof(DOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << "GB / s," + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_tensor_size, + in_tensor_size, + window_spatial_lengths, + window_strides, + window_dilations); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool_fwd/max_pool2d_fwd.cpp b/client_example/19_pool/max_pool2d_fwd.cpp similarity index 62% rename from client_example/19_pool_fwd/max_pool2d_fwd.cpp rename to client_example/19_pool/max_pool2d_fwd.cpp index 12ee6192012c8b6114ecade29c0580731e0cb90d..84b818a60fdec6a62ffe027285bdbec854721d5a 100644 --- a/client_example/19_pool_fwd/max_pool2d_fwd.cpp +++ b/client_example/19_pool/max_pool2d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -10,14 +10,18 @@ #include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp" +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" using InDataType = ck::half_t; using OutDataType = ck::half_t; using IndexDataType = int32_t; -constexpr ck::index_t InOutRank = 4; -constexpr ck::index_t WindowRank = 2; +// We use pool3d to implement pool2d in this example +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + +constexpr ck::index_t InOutRank = 5; +constexpr ck::index_t WindowRank = 3; #if 1 constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; constexpr bool OutputIndex = true; @@ -42,31 +46,66 @@ struct SimpleDeviceMem void* p_mem_; }; +void TransformPool2dparamToPool3d(std::vector& input_lengths, + std::vector& window_lengths, + std::vector& output_lengths, + std::vector& input_stride, + std::vector& output_stride, + std::vector& indices_stride, + std::vector& window_strides, + std::vector& window_dilations, + std::vector& input_left_pads, + std::vector& input_right_pads, + std::vector& pooling_dims) +{ + // NCHW to NCDHW + input_lengths.insert(input_lengths.begin() + 2, 1); + output_lengths.insert(output_lengths.begin() + 2, 1); + input_stride.insert(input_stride.begin() + 2, 0); + output_stride.insert(output_stride.begin() + 2, 0); + indices_stride.insert(indices_stride.begin() + 2, 0); + + // YX to ZYX + window_lengths.insert(window_lengths.begin(), 1); + window_strides.insert(window_strides.begin(), 0); + window_dilations.insert(window_dilations.begin(), 0); + input_left_pads.insert(input_left_pads.begin(), 0); + input_right_pads.insert(input_right_pads.begin(), 0); + + pooling_dims = {2, 3, 4}; +} + int main(int argc, char* argv[]) { - ck::index_t N = 2; - ck::index_t C = 32; - ck::index_t Y = 2; - ck::index_t X = 2; - ck::index_t Hi = 30; - ck::index_t Wi = 30; - ck::index_t window_stride_h = 2; - ck::index_t window_stride_w = 2; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; - - ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; - ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; // Pool API only support the order of NCHW std::vector in_length = {N, C, Hi, Wi}; std::vector out_length = {N, C, Ho, Wo}; std::vector window_spatial_lengths = {Y, X}; std::vector window_strides = {window_stride_h, window_stride_w}; + std::vector window_dilations = {window_dilation_h, window_dilation_w}; std::vector input_left_pads = {in_left_pad_h, in_left_pad_w}; std::vector input_right_pads = {in_right_pad_h, in_right_pad_w}; + std::vector pooling_dims = {2, 3}; std::size_t in_tensor_size = N * C * Hi * Wi; std::size_t out_tensor_size = N * C * Ho * Wo; @@ -75,6 +114,18 @@ int main(int argc, char* argv[]) std::vector in_tensor_stride = {C * Hi * Wi, 1, Wi * C, C}; std::vector out_tensor_stride = {C * Ho * Wo, 1, Wo * C, C}; + TransformPool2dparamToPool3d(in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size); @@ -84,6 +135,8 @@ int main(int argc, char* argv[]) InDataType, OutDataType, IndexDataType, + InLayout, + OutLayout, ReduceOpId, OutputIndex>; @@ -116,9 +169,10 @@ int main(int argc, char* argv[]) out_tensor_stride, out_tensor_stride, window_strides, + window_dilations, input_left_pads, input_right_pads, - {2, 3}); + pooling_dims); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -175,9 +229,10 @@ int main(int argc, char* argv[]) out_tensor_stride, out_tensor_stride, window_strides, + window_dilations, input_left_pads, input_right_pads, - {2, 3}); + pooling_dims); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/client_example/19_pool_fwd/CMakeLists.txt b/client_example/19_pool_fwd/CMakeLists.txt deleted file mode 100644 index 13f9f73c83d55c801fdf4609e13f6b0813cb0c67..0000000000000000000000000000000000000000 --- a/client_example/19_pool_fwd/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) -target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_operations) - -add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp) -target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_operations) \ No newline at end of file diff --git a/client_example/20_image_to_column/CMakeLists.txt b/client_example/20_image_to_column/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..80edcd04169cac18785b84847dcb1e23e7d596cd --- /dev/null +++ b/client_example/20_image_to_column/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_image_to_column image_to_column.cpp) +target_link_libraries(client_image_to_column PRIVATE composable_kernel::device_operations) diff --git a/client_example/20_image_to_column/image_to_column.cpp b/client_example/20_image_to_column/image_to_column.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ace4c1a681b77b52dfeac9f79192c12f851d7e3a --- /dev/null +++ b/client_example/20_image_to_column/image_to_column.cpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/image_to_column.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 32; // batch size +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + + std::array in_spatial_lengths{Hi, Wi}; + std::array wei_spatial_lengths{Y, X}; + std::array out_spatial_lengths{Ho, Wo}; + + // We have NHWGC in memory space (G is dummy) + // However, CK's API only accept length and stride with order of GNCHW + // Hence, we need to adjust the order of stride + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array out_strides{Y * X * C, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * Y * X * C); + + using DeviceOp = ck::tensor_operation::device:: + DeviceImageToColumn; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + in_strides, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(OutDataType) * N * Ho * Wo * Y * X * C; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(avg_time < best_avg_time) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + in_strides, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a60bada473b38f041bc409ef42dc4fabb7ecc57d --- /dev/null +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) +target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) diff --git a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94a57cd029a30d0fc73bae8cb8e43c1d475c95bc --- /dev/null +++ b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp @@ -0,0 +1,225 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +using ADataType = F8; +using BDataType = F16; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 8) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideC = std::stoi(argv[6]); + + KBatch = std::stoi(argv[7]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideC, KBatch\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/21_grouped_gemm_bias/CMakeLists.txt b/client_example/21_grouped_gemm_bias/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a2abd15731fcf97af99646d5dced7aa944a974aa --- /dev/null +++ b/client_example/21_grouped_gemm_bias/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations) diff --git a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3b6dd9a2a92ad75e4427d8e825a18df307cfea4a --- /dev/null +++ b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; + + int group_count = Ms.size(); + + for(int i = 0; i < group_count; ++i) + { + Ns.push_back(768); + Ks.push_back(4608); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, d0_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + d0_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + d0_dev_bufs.emplace_back(sizeof(D0DataType) * + f_matrix_space_size(Ms[i], Ns[i], 0, D0Layout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back( + {a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + std::array{d0_dev_bufs[i].GetDeviceBuffer()}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + std::array{0}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 2); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 14c066e4a21ca145eb5ddff3e28f6bd3a45b2449..eb793b3cbd6e693fd8cf94ef3489c6ac3eb915c7 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -2,7 +2,53 @@ cmake_minimum_required(VERSION 3.15) project(ck_app) add_compile_options(-std=c++17) -find_package(composable_kernel 1.0.0 COMPONENTS device_operations) +if (DTYPES) + add_definitions(-DDTYPES) + if (DTYPES MATCHES "int8") + add_definitions(-DCK_ENABLE_INT8) + if(NOT DEFINED ${CK_ENABLE_INT8}) + set(CK_ENABLE_INT8 "ON") + endif() + endif() + if (DTYPES MATCHES "fp8") + add_definitions(-DCK_ENABLE_FP8) + if(NOT DEFINED ${CK_ENABLE_FP8}) + set(CK_ENABLE_FP8 "ON") + endif() + endif() + if (DTYPES MATCHES "fp16") + add_definitions(-DCK_ENABLE_FP16) + if(NOT DEFINED ${CK_ENABLE_FP16}) + set(CK_ENABLE_FP16 "ON") + endif() + endif() + if (DTYPES MATCHES "fp32") + add_definitions(-DCK_ENABLE_FP32) + if(NOT DEFINED ${CK_ENABLE_FP32}) + set(CK_ENABLE_FP32 "ON") + endif() + endif() + if (DTYPES MATCHES "fp64") + add_definitions(-DCK_ENABLE_FP64) + if(NOT DEFINED ${CK_ENABLE_FP64}) + set(CK_ENABLE_FP64 "ON") + endif() + endif() + if (DTYPES MATCHES "bf16") + add_definitions(-DCK_ENABLE_BF16) + if(NOT DEFINED ${CK_ENABLE_BF16}) + set(CK_ENABLE_BF16 "ON") + endif() + endif() + message("DTYPES macro set to ${DTYPES}") +else() + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES}) + set(CK_ENABLE_ALL_DTYPES "ON") + endif() +endif() + +find_package(composable_kernel COMPONENTS device_operations) find_package(hip REQUIRED PATHS /opt/rocm) message(STATUS "Build with HIP ${hip_VERSION}") diff --git a/cmake/DoxygenDoc.cmake b/cmake/DoxygenDoc.cmake index 2e3669fcdf48fde87c23a84c2493ce26aecda7af..c91308b5bb6db3cac55628c2bf5123f261cc838d 100644 --- a/cmake/DoxygenDoc.cmake +++ b/cmake/DoxygenDoc.cmake @@ -309,6 +309,8 @@ XML_OUTPUT XML_PROGRAMLISTING ) +set(WARN_AS_ERROR YES) + set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") function(add_doxygen_doc) diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 369cd0b54c1fac60b061cdafdb0b02b819a8accd..66139cc710992df07151b85537c28f483e2ef452 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -67,6 +67,7 @@ else() -Wunused -Wno-reserved-identifier -Werror + -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt ) diff --git a/docs/API_Reference_Guide.rst b/docs/API_Reference_Guide.rst index b59c6e302692cd6b9d4035ed27d5901e790a0233..f21d43c5939ab07d3e9b6cc2c6c361a782b13ff1 100644 --- a/docs/API_Reference_Guide.rst +++ b/docs/API_Reference_Guide.rst @@ -7,8 +7,8 @@ API Reference Guide Introduction ================= -This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design -principles that are used to write new classes that extend CK functionality. +This document contains details of the APIs for the Composable Kernel (CK) library and introduces +some of the key design principles that are used to write new classes that extend CK functionality. ================= Using CK API @@ -30,8 +30,8 @@ DeviceMem Kernels For Flashattention --------------------------- -The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists the classes that are -used in the CK GPU implementation of Flashattention. +The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists +the classes that are used in the CK GPU implementation of Flashattention. **Gridwise classes** diff --git a/docs/Contributors_Guide.rst b/docs/Contributors_Guide.rst index b2ddff398ce8efcd0c5d8be275d9dd09315d4349..41cb8f19156b389930a886be82aa5474b5465a37 100644 --- a/docs/Contributors_Guide.rst +++ b/docs/Contributors_Guide.rst @@ -2,7 +2,101 @@ Contributor's Guide =================== -Pull-request guidelines -======================= +This chapter explains how to get started contributing to the Composable Kernel project and what are +the contributing rules. -[TODO] +Getting started +=============== + +#. **Documentation:** Before contributing to the library, familiarize yourself with the + `Composable Kernel User Guide `_. + It provides insight into the core concepts, environment configuration, and steps to obtain or + build the library. You can also find some of this information in the + `README file `_ + on the project's GitHub page. +#. **Additional reading:** We also recommend reading a `blog post + `_ + from the AMD Community portal. It offers a deeper understanding of the library's objectives and + showcases its performance capabilities. +#. **General information:** For broader information about AMD products, consider exploring the + `AMD Developer Central portal `_. + +How do I contribute +=================== + +We deeply value contributions from our users. You can make an impact by reporting issues or +proposing code enhancements through pull requests. + +Reporting issues +---------------- + +We use `Github issues `_ +to track public bugs and enhancement requests. + +If you encounter an issue with the library, please check if the problem has already been +reported by searching existing issues on GitHub. If your issue seems unique, please submit a new +issue. All reported issues must include: + +* A comprehensive description of the problem, including: + + * What did you observe? + * Why do you think it is a bug (if it seems like one)? + * What did you expect to happen? What would indicate the resolution of the problem? + * Are there any known workarounds? + +* Your configuration details, including: + + * Which GPU are you using? + * Which OS version are you on? + * Which ROCm version are you using? + * Are you using a Docker image? If so, which one? + +* Steps to reproduce the issue, including: + + * What actions trigger the issue? What are the reproduction steps? + + * If you build the library from scratch, what CMake command did you use? + + * How frequently does this issue happen? Does it reproduce every time? Or is it a sporadic issue? + +Before sumbitting any issue, ensure you have addressed all relevant questions from the checklist. + +Creating Pull Requests +---------------------- + +You can submit `Pull Requests (PR) on GitHub +`_. + +All contributors are required to develop their changes on a separate branch and then create a +pull requrest to merge their changes into the `develop` branch, which is the default +development branch in the Composable Kernel project. All external contributors must use their own +forks of the project to develop their changes. + +When submitting a Pull Request you should: + +* Describe the change providing information about the motivation for the change and a general + description of all code modifications. + +* Verify and test the change: + + * Run any relevant existing tests. + * Write new tests if added functionality is not covered by current tests. + +* Ensure your changes align with the coding style defined in the ``.clang-format`` file located in + the project's root directory. We leverage `pre-commit` to run `clang-format` automatically. We + highly recommend contributors utilize this method to maintain consistent code formatting. + Instructions on setting up `pre-commit` can be found in the project's + `README file `_ + +* Link your PR to any related issues: + + * If there is an issue that is resolved by your change, please provide a link to the issue in + the description of your pull request. + +* For larger contributions, structure your change into a sequence of smaller, focused commits, each + addressing a particular aspect or fix. + +Following the above guidelines ensures a seamless review process and faster assistance from our +end. + +Thank you for your commitment to enhancing the Composable Kernel project! We look forward to collaborating with you. diff --git a/docs/Supported_Primitives_Guide.rst b/docs/Supported_Primitives_Guide.rst index 4c3adf67d7119e22a622373ab8d1dccb4d024e18..3462283d90d62f1a181a1cc26d04a36f9d6f4fff 100644 --- a/docs/Supported_Primitives_Guide.rst +++ b/docs/Supported_Primitives_Guide.rst @@ -2,15 +2,16 @@ Supported Primitives Guide ========================== -This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference -Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK. +This document contains details of supported primitives in Composable Kernel (CK). In contrast to the +API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins +the algorithms implemented in CK. ------------ Softmax ------------ -For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the softmax of concatenated -:math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, +For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the +softmax of concatenated :math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, .. math:: :nowrap: @@ -25,8 +26,8 @@ For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can d where :math:`f(x^{(j)}) = \exp( x^{(j)} - m(x^{(j)}) )` is of size :math:`B` and :math:`z(x^{(j)}) = f(x_1^{(j)})+ \ldots+ f(x_B^{(j)})` is a scalar. -For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size :math:`B_r \times B_c` we can -compute the row-wise softmax as follows. +For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size +:math:`B_r \times B_c` we can compute the row-wise softmax as follows. For :math:`j` from :math:`1` to :math:`T_c`, and :math:`i` from :math:`1` to :math:`T_r` calculate, diff --git a/docs/dockerhub.rst b/docs/dockerhub.rst index b51226cfebe6da27988ec2685eb334f5bc900ff9..66ec91096e01b9ebfbe78559b9a168711a7b4206 100644 --- a/docs/dockerhub.rst +++ b/docs/dockerhub.rst @@ -1,27 +1,27 @@ =================== -CK docker hub +CK Docker Hub =================== -`Docker hub `_ - ------------------------------------- Why do I need this? ------------------------------------- -To make our lives easier and bring Composable Kernel dependencies together, we recommend using docker images. +To make our lives easier and bring Composable Kernel dependencies together, we recommend using +docker images that can be found on `Docker Hub `_. ------------------------------------- So what is Composable Kernel? ------------------------------------- -Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. +Composable Kernel (CK) library aims to provide a programming model for writing performance critical +kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, +through general purpose kernel languages, like HIP C++. To get the CK library:: git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git - run a docker container:: docker run \ @@ -30,7 +30,7 @@ run a docker container:: --group-add sudo \ -w /root/workspace \ -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ + rocm/composable_kernel:ck_ub20.04_rocm5.6 \ /bin/bash and build the CK:: @@ -58,7 +58,9 @@ We can also run specific examples or tests like:: ./bin/example_gemm_xdl_fp16 ./bin/test_gemm_fp16 -For more details visit `CK github repo `_, `CK examples `_, `even more CK examples `_. +For more details visit `CK github repository `_, +`CK examples `_, +`even more CK examples `_. ------------------------------------- And what is inside? @@ -74,12 +76,11 @@ The docker images have everything you need for running CK including: Which image is right for me? ------------------------------------- -Let's take a look at the image naming, for example "ck_ub20.04_rocm5.4_release". The image specs are: +Let's take a look at the image naming, for example ``ck_ub20.04_rocm5.6``. The image specs are: -* "ck" - made for running Composable Kernel -* "ub20.04" - based on Ubuntu 20.04 -* "rocm5.4" - ROCm platform version 5.4 -* "release" - compiler version is release +* ``ck`` - made for running Composable Kernel; +* ``ub20.04`` - based on Ubuntu 20.04; +* ``rocm5.6`` - ROCm platform version 5.6. So just pick the right image for your project dependencies and you're all set. @@ -87,7 +88,9 @@ So just pick the right image for your project dependencies and you're all set. DIY starts here ------------------------------------- -If you need to customize a docker image or just can't stop tinkering, feel free to adjust the `Dockerfile `_ for your needs. +If you need to customize a docker image or just can't stop tinkering, feel free to adjust the +`Dockerfile `_ +for your needs. ------------------------------------- License diff --git a/docs/index.rst b/docs/index.rst index f4e66c1b51f4528fa47057116a8f457680b3d972..51c0c862ae3e3e8b314aa37bd63b6d762b895607 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,12 +12,15 @@ This document contains instructions for installing, using, and contributing to C Methodology ----------- -Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. +Composable Kernel (CK) library aims to provide a programming model for writing performance critical +kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, +through general purpose kernel languages, like HIP C++. CK utilizes two concepts to achieve performance portability and code maintainability: * A tile-based programming model -* Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". +* Algorithm complexity reduction for complex ML operators, using innovative technique we call + "Tensor Coordinate Transformation". .. image:: data/ck_component.png :alt: CK Components diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 4bdf41b959975f9b7f6ba38d20d5070f83e4da1e..a12a00a2b21ec12d8d6f41d0838309350d2eefd1 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.10.3 +rocm-docs-core>=0.20.0 sphinxcontrib-bibtex==2.5.0 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 097acba2257d3c3f1c03d801ebeebdccd0dd120f..d5f67eeb585b3b457ee49bf5b1b46468f76e742a 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -38,6 +38,8 @@ docutils==0.16 # pydata-sphinx-theme # sphinx # sphinxcontrib-bibtex +fastjsonschema==2.18.0 + # via rocm-docs-core gitdb==4.0.10 # via gitpython gitpython==3.1.31 @@ -46,20 +48,12 @@ idna==3.4 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.0.0 - # via - # sphinx - # sphinxcontrib-bibtex -importlib-resources==5.12.0 - # via rocm-docs-core jinja2==3.1.2 # via # myst-parser # sphinx latexcodec==2.0.1 # via pybtex -linkify-it-py==1.0.3 - # via myst-parser markdown-it-py==2.2.0 # via # mdit-py-plugins @@ -70,7 +64,7 @@ mdit-py-plugins==0.3.5 # via myst-parser mdurl==0.1.2 # via markdown-it-py -myst-parser[linkify]==1.0.0 +myst-parser==1.0.0 # via rocm-docs-core packaging==23.0 # via @@ -99,18 +93,17 @@ pyjwt[crypto]==2.6.0 # via pygithub pynacl==1.5.0 # via pygithub -pytz==2023.3 - # via babel pyyaml==6.0 # via # myst-parser # pybtex + # rocm-docs-core # sphinx-external-toc requests==2.28.2 # via # pygithub # sphinx -rocm-docs-core==0.10.3 +rocm-docs-core>=0.20.0 # via -r requirements.in six==1.16.0 # via @@ -160,13 +153,7 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx typing-extensions==4.5.0 # via pydata-sphinx-theme -uc-micro-py==1.0.1 - # via linkify-it-py urllib3==1.26.15 # via requests wrapt==1.15.0 # via deprecated -zipp==3.15.0 - # via - # importlib-metadata - # importlib-resources diff --git a/docs/tutorial_hello_world.rst b/docs/tutorial_hello_world.rst index b8fd094654cba671b5082d8052b7dd3d6671867a..bfb197e085a2a986ae50aea6b6eeebd8d203eab1 100644 --- a/docs/tutorial_hello_world.rst +++ b/docs/tutorial_hello_world.rst @@ -6,15 +6,26 @@ CK Hello world Motivation ------------------------------------- -This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who would like to optimize their pipelines and squeeze every performance drop by adding Composable Kernel (CK) library to their projects. We would like to make the CK library approachable so the tutorial is not based on the latest release and doesn't have all the bleeding edge features, but it will be reproducible now and forever. +This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who +would like to optimize their pipelines and squeeze every performance drop by adding Composable +Kernel (CK) library to their projects. We would like to make the CK library approachable so +the tutorial is not based on the latest release and doesn't have all the bleeding edge features, +but it will be reproducible now and forever. -During this tutorial we will have an introduction to the CK library, we will build it and run some examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go in depth and breadth and get familiar with other tools and ways to integrate CK into your project. +During this tutorial we will have an introduction to the CK library, we will build it and run some +examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go +in depth and breadth and get familiar with other tools and ways to integrate CK into your project. ------------------------------------- Description ------------------------------------- -Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create new ones. The library has components required for majority of modern neural networks architectures including matrix multiplication, convolution, contraction, reduction, attention modules, variety of activation functions, fused operators and many more. +Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and +efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast +and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create +new ones. The library has components required for majority of modern neural networks architectures +including matrix multiplication, convolution, contraction, reduction, attention modules, variety of +activation functions, fused operators and many more. So how do we (almost) reach the speed of light? CK acceleration abilities are based on: @@ -24,15 +35,18 @@ So how do we (almost) reach the speed of light? CK acceleration abilities are ba * Hardware acceleration use. * Support of low precision data types including fp16, bf16, int8 and int4. -If you are excited and need more technical details and benchmarking results - read this awesome `blog post `_. +If you are excited and need more technical details and benchmarking results - read this awesome +`blog post `_. -For more details visit our `github repo `_. +For more details visit our `github repository `_. ------------------------------------- Hardware targets ------------------------------------- -CK library fully supports "gfx908" and "gfx90a" GPU architectures and only some operators are supported for "gfx1030". Let's check the hardware you have at hand and decide on the target GPU architecture +CK library fully supports `gfx908` and `gfx90a` GPU architectures and only some operators are +supported for `gfx1030`. Let's check the hardware you have at hand and decide on the target +GPU architecture. ========== ========= GPU Target AMD GPU @@ -42,7 +56,8 @@ gfx90a Radeon Instinct MI210, MI250, MI250X gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT ========== ========= -There are also `cloud options `_ you can find if you don't have an AMD GPU at hand. +There are also `cloud options `_ you can find if +you don't have an AMD GPU at hand. ------------------------------------- Build the library @@ -54,9 +69,13 @@ First let's clone the library and rebase to the tested version:: cd composable_kernel/ git checkout tutorial_hello_world -To make our lives easier we prepared `docker images `_ with all the necessary dependencies. Pick the right image and create a container. In this tutorial we use "rocm/composable_kernel:ck_ub20.04_rocm5.3_release" image, it is based on Ubuntu 20.04, ROCm v5.3, compiler release version. +To make our lives easier we prepared +`docker images `_ with all the necessary +dependencies. Pick the right image and create a container. In this tutorial we use +``rocm/composable_kernel:ck_ub20.04_rocm5.6`` image, it is based on Ubuntu 20.04 and +ROCm v5.6. -If your current folder is ${HOME}, start the docker container with:: +If your current folder is ``${HOME}``, start the docker container with:: docker run \ -it \ @@ -64,20 +83,23 @@ If your current folder is ${HOME}, start the docker container with:: --group-add sudo \ -w /root/workspace \ -v ${HOME}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ + rocm/composable_kernel:ck_ub20.04_rocm5.6 \ /bin/bash -If your current folder is different from ${HOME}, adjust the line `-v ${HOME}:/root/workspace` to fit your folder structure. +If your current folder is different from ``${HOME}``, adjust the line ``-v ${HOME}:/root/workspace`` +to fit your folder structure. -Inside the docker container current folder is "~/workspace", library path is "~/workspace/composable_kernel", navigate to the library:: +Inside the docker container current folder is ``~/workspace``, library path is +``~/workspace/composable_kernel``, navigate to the library:: cd composable_kernel/ -Create and go to the "build" directory:: +Create and go to the ``build`` directory:: mkdir build && cd build -In the previous section we talked about target GPU architecture. Once you decide which one is right for you, run cmake using the right GPU_TARGETS flag:: +In the previous section we talked about target GPU architecture. Once you decide which one is right +for you, run CMake using the right ``GPU_TARGETS`` flag:: cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ @@ -87,7 +109,7 @@ In the previous section we talked about target GPU architecture. Once you decide -D BUILD_DEV=OFF \ -D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. -If everything went well the cmake run will end up with:: +If everything went well the CMake run will end up with:: -- Configuring done -- Generating done @@ -118,9 +140,12 @@ We can also run them separately, here is a separate example execution:: ./bin/example_gemm_xdl_fp16 1 1 1 -The arguments "1 1 1" mean that we want to run this example in the mode: verify results with CPU, initialize matrices with integers and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. +The arguments ``1 1 1`` mean that we want to run this example in the mode: verify results with CPU, +initialize matrices with integers and benchmark the kernel execution. You can play around with +these parameters and see how output and execution results change. -If everything goes well and you have a device based on gfx908 or gfx90a architecture you should see something like:: +If everything goes well and you have a device based on `gfx908` or `gfx90a` architecture you should see +something like:: a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} @@ -130,14 +155,15 @@ If everything goes well and you have a device based on gfx908 or gfx90a architec Start running 10 times... Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 -Meanwhile, running it on a gfx1030 device should result in:: +Meanwhile, running it on a `gfx1030` device should result in:: a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem -But don't panic, some of the operators are supported on gfx1030 architecture, so you can run a separate example like:: +But don't panic, some of the operators are supported on `gfx1030` architecture, so you can run a +separate example like:: ./bin/example_gemm_dl_fp16 1 1 1 @@ -154,7 +180,14 @@ and it should result in something nice similar to:: Start running 10 times... Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> -Or we can run a separate test:: +.. note:: + + There was a new CMake flag ``DL_KERNELS`` added in the latest versions of CK. If you use one of + the newest versions of the library and do not see the above results when running + ``example_gemm_dl_fp16``, it might be necessary to add ``-D DL_KERNELS=ON`` to your CMake command + in order to build the operators supported on the `gfx1030` architecture. + +We can also run a separate test:: ctest -R test_gemm_fp16 @@ -169,6 +202,9 @@ If everything goes well you should see something like:: Summary ----------- -In this tutorial we took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different configs to find out the best one for your hardware and task. +In this tutorial we took the first look at the Composable Kernel library, built it on your system +and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different +configs to find out the best one for your hardware and task. -P.S.: Don't forget to switch out the cloud instance if you have launched one, you can find better ways to spend your money for sure! +P.S.: Don't forget to switch off the cloud instance if you have launched one, you can find better +ways to spend your money for sure! diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index c5a8295188ac2e088a320fbe25a264ef8aeec2cb..3dc2a0966e6f86882d6471f50367dc04a668a178 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,46 +1,73 @@ -add_custom_target(example_gemm_dl) +if(DL_KERNELS) + add_custom_target(example_gemm_dl) -add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) -add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) -add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) + add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_fp32) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_fp16) + add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_int8) + endif() -add_dependencies(example_gemm_dl example_gemm_dl_fp32) -add_dependencies(example_gemm_dl example_gemm_dl_fp16) -add_dependencies(example_gemm_dl example_gemm_dl_int8) + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) + add_dependencies(example_gemm_dl example_gemm_dl_int4) + endif(USE_BITINT_EXTENSION_INT4) +endif() -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_int4) -endif(USE_BITINT_EXTENSION_INT4) +add_custom_target(example_gemm_xdl) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) + add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) + add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) + if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") + add_custom_target(example_gemm_wmma) + add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) + add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) + endif() -add_custom_target(example_gemm_xdl) +endif() -add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) -add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) -add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) -add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) -add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -add_dependencies(example_gemm_xdl example_gemm_xdl_int8) -add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) + add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) +endif() + +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_int8) +endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) add_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) -# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed + add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) +endif() -add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) +add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_custom_target(example_gemm_wmma) - add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) +if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) + if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") + add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_f8) + endif() endif() +add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) +add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 495a8159623beb22d6002bb2e1667ef459b74139..7fd15b2833d2d7522fa3d87029493415fd596fad 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -33,6 +33,19 @@ struct ProblemSize final ck::index_t StrideC = 4096; }; +struct ProblemSizeStreamK final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t NumSKBlocks = -1; +}; + struct ExecutionConfig final { bool do_verification = true; @@ -48,8 +61,17 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -inline bool -parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +template +bool parse_cmd_args(int, char*[], ProblemType&, ExecutionConfig&) +{ + return false; +} + +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSize& problem_size, + ExecutionConfig& config) { if(argc == 1) { @@ -87,3 +109,52 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi return true; } + +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeStreamK& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.NumSKBlocks = std::stoi(argv[10]); + } + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg10: NumSKBlocks(optional)" << std::endl; + return false; + } + + return true; +} diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index cf585a8c51cbb9b5b0218228e4aa706189598804..b5fecb97521bd2d4213a9435b01dafe45a43be6e 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 93f085cdee53667a5b906cfed3b037d57d00bf5f..212b72f2a6a060a602c5c640709da77684c7c55e 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp index e392c490f29a48da3a6424f46876e220306f2907..e55ae140130c0779c1e81e0bdc84b6724c6bac86 100644 --- a/example/01_gemm/gemm_dl_int4.cpp +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index be9e387718f120fb1ba708d374f0ff9a09fc806d..1840390aa9e02e85c88328baf358546db362ab24 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_dpp_fp16.cpp b/example/01_gemm/gemm_dpp_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a9e3f6186cc1525c8d882e9a961ae6ec06f95b2 --- /dev/null +++ b/example/01_gemm/gemm_dpp_fp16.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CDataType = ck::half_t; + +using F16 = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 64, 64, 64, 8, 2, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 5, 1>; +// // clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 58f965be8817bb5761dd677a52966d21c1bc1024..b11fe76ab2ce031708be7addd7f84be0f15adc6b 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 9aaae6ade9564aa54bc4ea0c7f2c96aac89ff505..3cac55ef4702856dd08dc223c02ea25113dbbf32 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc14dcb8eb48f37b7b3597a02407ef440e17ab37 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_rtn.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/utility/type_convert.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_f8.cpp b/example/01_gemm/gemm_xdl_f8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..10159267776f6f33469affc618192e29d62763ac --- /dev/null +++ b/example/01_gemm/gemm_xdl_f8.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using CDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 50d35fd9ac98de821fe4861b2dbf34907d3592eb..54fbd9cdd4983c1a810f3c8a1ac1521b363af690 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_fp16_f8.cpp b/example/01_gemm/gemm_xdl_fp16_f8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d3cf3d397a5fd329a58dc3a967aa4c8f0bb2b696 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_f8.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeType = ck::half_t; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| ComputeType| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| | +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 99253b743d58707707c9765c18e5933cb09e4220..8361576299c3da3727cc467387e9717fceb93ce3 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp index 7f1283a47b36c94be9645abd9fbd63094293f46e..f6238c7aa5040d8e409a139fe8c144835282cdc9 100644 --- a/example/01_gemm/gemm_xdl_int4.cpp +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index e67594c5bcbd601ae4747ac7720dab66b605ba2c..cc03200b9d153e8e22e13e6c77a52ad76ed79174 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index 12a69925977b7eef2e5aa178016a7d433e9c0781..4a0c23cf44c6de8fe40f1f581e6016c597ae8041 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -204,9 +204,9 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d433b6145017c3aa19f4ae5a9e4bf31101b1025 --- /dev/null +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::half_t; + +using F16 = ck::half_t; + +using ALayout = Row; +using BLayout = Row; +// using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +// clang-format off +using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + + // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>; + // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>; + // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>; + + + +// // clang-format on +// clang-format on + +using DeviceGemmInstance = DeviceGemmStreamK; + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index 3a0ddd90b700f7086b44a8b0f4e5b9d7385d0ad2..b0f963fee51c20b089811ba2b6396b9d7125d8f6 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp" using ADataType = ck::half_t; using BDataType = ck::half_t; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 4e2cedb52adc8c7030167ec7f171d91f0796f234..7be2539d903974cd676e92ddb3d692a997abf1cf 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -1,9 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { #if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); @@ -11,7 +14,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) using namespace ck::literals; - auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size; + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -25,12 +33,37 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); switch(config.init_method) { - case 0: break; + case 0: + ck::utils::FillConstant{static_cast(1.f)}(a_m_k); + ck::utils::FillConstant{static_cast(1.f)}(b_k_n); + break; case 1: ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); @@ -66,43 +99,115 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); #endif + DeviceMem workspace; auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; + using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK; + // do GEMM - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument( + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + if constexpr(std::is_same::value && + !std::is_base_of::value) + { + auto argument = gemm.MakeArgument( #ifdef BUILD_INT4_EXAMPLE - static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), #else - static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), #endif - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); - - if(!gemm.IsSupportedArgument(argument)) + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + } + else if constexpr(std::is_same::value && + std::is_base_of::value) { - std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; - - return true; + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + problem_size.NumSKBlocks); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); + if(workspace_size != 0) + { + workspace.Realloc(workspace_size); + gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer()); + } + + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + +#if 0 + // TODO!!!!! + if(workspace_size != 0){ + float * ws_ptr = reinterpret_cast(malloc(workspace_size)); + size_t ws_dwords = workspace_size / sizeof(float); + workspace.FromDevice(ws_ptr); + + for(size_t i = 0; i < ws_dwords; i++) { + uint32_t rere = reinterpret_cast(ws_ptr)[i]; + printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere); + } + free(ws_ptr); + } +#endif } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; @@ -149,3 +254,11 @@ bool run_gemm_example(int argc, char* argv[]) return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); } + +bool run_gemm_streamk_example(int argc, char* argv[]) +{ + ProblemSizeStreamK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index dda234e6055d54b2c5a28528f625a127b4a0e10d..52e63305232598bb0c81419b92cc81ac97c7e73e 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,7 +1,22 @@ -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) +list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) endif() if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") + set(target 1) + endif() +endforeach() + +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() endif() diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index ff99bf4641115e466775995de2761ba52897783d..877792d7409f48ee562cded6629f7e1bbde650d9 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index 917b6b1c3142ec8cf7d1c852ea59494fea1842a9..c3e6ef7d5df2e12de69e2f9465f8f3e3709c0859 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index 8834a910f0367946c96d44e8cfeef653be322512..a247a052cba40410e8879a57a0e3c35a57b5645f 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1,3 +1,10 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) -endif() \ No newline at end of file + set(target 1) + endif() +endforeach() +endif() diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp index aee51d05de58b6386a5dd267d45a7e5c12d98276..dffeff23374b6f868fcce952ef8991ca289f59b7 100644 --- a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp +++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index a706830b68f5fea22e533e4ec03dd9c98e311b27..15ec62c89fcd39e24e7bb1ae5f8e7d366f2138d6 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,19 +1,28 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_gemm_add_add_fastgelu_xdl) - - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) + endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) - - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) - if(USE_BITINT_EXTENSION_INT4) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) -endif() \ No newline at end of file + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + endif() + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/example/04_gemm_add_add_fastgelu/common.hpp b/example/04_gemm_add_add_fastgelu/common.hpp index 839587c148956cf186e1ba5ac286a1b73797fade..91d17df95fc470ea23011999d0a929b9331d0cc2 100644 --- a/example/04_gemm_add_add_fastgelu/common.hpp +++ b/example/04_gemm_add_add_fastgelu/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp index ba0476b9b9e3bc031cb0ea61eda66970c2544f39..e630f6783713517977f63a9a2a1ef20836af2c61 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp index b940bfd89737ff4f460ef2b672a99a6696d57ce3..71f6677bae95d1779c76219f6c8b5bc05c3c1d31 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp index fa651a34ea86298a60dbea6158c356b9e50fdfce..4665c3932f419c837125ea3153437302cb9dc022 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp index 9f9c423de278b8687816dac1d518f843eb2eb34a..f206bbeb411bb0de0144015d70a892b5587d6f0e 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp index fadc4ef5ee47737bb33e19f1bab3264f99de5ea2..e46483ab38ac34cb5d4f97602474887d78779327 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 1bcf2d148d0f23aa42126ed50867a9403293197c..1af1e6c8582f6f0f4ccf3902dc1dd0c1d5bd02d5 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,12 +1,35 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) - add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) - add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) - add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) + endif() # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) -endif() -add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) -add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) -add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) + if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) + endif() + set(target 1) + endif() +endforeach() +if(DL_KERNELS) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) + endif() +endif() diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 4c594ccdf817a069918b348d7b23c6cc533bde28..109b8f9ee34096601f0dd485db6840e0674c1937 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/09_convnd_fwd/convnd_fwd_dl_common.hpp b/example/09_convnd_fwd/convnd_fwd_dl_common.hpp index 855710b9d9a30ec47c74a6f31d1b03ade688e28e..aeddd4fc59fc521a196487bf7c0ea4e88bf6e1c8 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp index db5a7f0bc3351c9f1670993942a54668c191e56f..564fadcbf8b10acc98e47a67cddd2cb0abe295bc 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_dl_common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp index 964d784c8592b2e94bc38cc8eab6b7a645d88fc1..0bd90c999233b5392f5914081196ff3b06cc8285 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_dl_common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp index b0cd88f214c8550c018ad9f4d268370dbd816c7f..9b0c8b31f172dcaf984c7c0c728a65fcb1a0554a 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_dl_common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp index d55d3154916b0f4ba2fee195b7957904ae3d3139..74cf91d16017bd24096dac6348e5a8815856a51a 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index d84afba6426b1abe4b10a3b4da1fefe7c2e9272a..f6d69bafd48a4b8bff968c2fb8600793b1310294 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index f5acc540cf98b8876caa0d5c0812436dd200deb8..6c3171f6157ac4a83a488af12089a97accfd2c7c 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp index 8d697976abd43e5d07c49ec4e0819cbdf15ac02b..9977a496d229876c083e0afe99e17ff0fa5fca05 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index 99f7f2565c748a950841dffa12a1d343df0051ab..bf084b3cc0b6ceab326b2b21e963e0ce2d4bdcad 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc b/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc index 697ada14ba960f7231db3aa0d5b2482ca38578b8..6474df1c355bc059af5429f96616b2192f46f292 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/09_convnd_fwd/run_convnd_fwd_example.inc b/example/09_convnd_fwd/run_convnd_fwd_example.inc index 36a68056f1d7a34b4309a06941f0ff1f477eaba8..49852ff6678f7c3fbbfce4ac8b1303876092cf40 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index de26462f693da0bd72cb7d76c0b0a0de875ee564..e7d941ae6b3dddee9f8dfa357cf3426045d0e828 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,15 +1,28 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_convnd_fwd_reduce_xdl) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -endif() \ No newline at end of file + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 00e370f2968df0409155d5407066e863574613ce..137b0d1ff0fb0a7cb7159d52f54b4cd3496629ae 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp index 6ff29b4b0ff0f11d6c9d3becbd3aaafdccf3020e..4ccacb0bcee93e72eb1fd23203e9404f1c23c478 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp index 02c19c2b63bf4f9ee663140b414c89f6221cff51..bf495725e8e77a37019d1d858bf3fe3b69e828d1 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp index 679bb5c0c45a1e2e74fc56d21ee6f7f73cae76ad..5848785673c340149d622fd782671d5fb174edab 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp index abdbdaf74d5c0bd2b87da17d315e8712cd896d52..bf7127502faac94e6858d37c389f316b7ead35a6 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp index cf86afa8e94957c01fc92acb5ed2286fcb52466c..3e1694cbe8cae679ef249b6654efc40f8b1fc1ab 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index b3a3891781766acd824bc1974b88f57fdd85b711..cebfeb51d63eac30243d8e3b0468b821f6fc1eb3 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index a7ee9990c1941a73b632f3cc1d32b14a00897cc2..9a736d4cfac9d1062c568e795592711e30d18586 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/12_reduce/reduce_blockwise_impl.hpp b/example/12_reduce/reduce_blockwise_impl.hpp index e6e3cc8d52bc9fe6568d98194e644cf75ad1950a..7f8394a7301cc3da8271e1e8c41482d0b06a2859 100644 --- a/example/12_reduce/reduce_blockwise_impl.hpp +++ b/example/12_reduce/reduce_blockwise_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp index dbb18a0d83f62d27750c7797842af9bea0eb1383..eb8b5c76d31a1a0fc7a1a7e085416f80be9abf7f 100644 --- a/example/12_reduce/reduce_blockwise_two_call.cpp +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/12_reduce/reduce_example_common.hpp b/example/12_reduce/reduce_example_common.hpp index 05f0a0edb25ba520d475522bd2963bc050ccc010..5f9a48804a7268cef5a4a065d0c875565bdb3e9e 100644 --- a/example/12_reduce/reduce_example_common.hpp +++ b/example/12_reduce/reduce_example_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/12_reduce/reduce_multiblock_atomic_add.cpp b/example/12_reduce/reduce_multiblock_atomic_add.cpp index c4d63a3add8b0d8016b11b8834a1385dcca5beb5..120e3f05957fdb9f937b290dfe09458db248f9c0 100644 --- a/example/12_reduce/reduce_multiblock_atomic_add.cpp +++ b/example/12_reduce/reduce_multiblock_atomic_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp b/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp index 905242fb6b5bca3c0dd7209a5d2827546036695e..fed62186448d58ce58c2a13c6dd184e205801185 100644 --- a/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp +++ b/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt index db09c03321e5ac0f3408dcffbc940b6c9c0879c5..d0f356757b3845cb1a72708f0cd5c367eb1042fc 100644 --- a/example/13_pool2d_fwd/CMakeLists.txt +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ -add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) -add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) - +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) +endif() diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp index 9abc98671389a3ebbadab0924d4334a6926feef5..3ce08fd2afceb083f16ff26e88caa7a0855779b0 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_common.hpp +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -39,31 +39,35 @@ bool pool_test(bool do_verification, ck::index_t Wi, ck::index_t window_stride_h, ck::index_t window_stride_w, + ck::index_t window_dilation_h, + ck::index_t window_dilation_w, ck::index_t in_left_pad_h, ck::index_t in_left_pad_w, ck::index_t in_right_pad_h, ck::index_t in_right_pad_w) { using DevicePoolFwdInstance = - ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< - InDataType, // InDataType - OutDataType, // OutDataType - IndexDataType, // IndexDataType - ComputeDataType, // ComputeDataType - ReduceOpId, - OutputIndex, - 64, // BlockSize - 64, // ReduceMThreadClusterSize - 1, // ReduceKThreadClusterSize - 4, // ReduceMThreadSliceSize - 1, // ReduceKThreadSliceSize - 4>; // InSrcOutDstVectorSize - - const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; - const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + ck::tensor_operation::device::DevicePool2dFwd_NHWC_NHWC; // InSrcOutDstVectorSize + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; const std::vector window_spatial_lengths{Y, X}; const std::vector window_strides{window_stride_h, window_stride_w}; + const std::vector window_dilations{window_dilation_h, window_dilation_w}; const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; @@ -123,6 +127,7 @@ bool pool_test(bool do_verification, {C * Ho * Wo, 1, Wo * C, C}, {C * Ho * Wo, 1, Wo * C, C}, window_strides, + window_dilations, input_left_pads, input_right_pads, {2, 3}); @@ -144,8 +149,8 @@ bool pool_test(bool do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB / s " << std::endl; bool pass = true; @@ -169,6 +174,7 @@ bool pool_test(bool do_verification, out_indices_n_c_ho_wo_host, window_spatial_lengths, window_strides, + window_dilations, input_left_pads, input_right_pads); diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp index 20c3e4701756d61ee7332f9c6d5c4b0a2be9b8f6..d767e9224893067785f1ad5301b513061d0490b0 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -34,18 +34,20 @@ int main(int argc, char* argv[]) bool time_kernel; // Pool shape - ck::index_t N = 128; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t window_stride_h = 2; - ck::index_t window_stride_w = 2; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; if(argc == 1) { @@ -59,31 +61,33 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = static_cast(std::stoi(argv[3])); } - else if(argc == 16) + else if(argc == 18) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = static_cast(std::stoi(argv[3])); - N = std::stoi(argv[4]); - C = std::stoi(argv[5]); - Y = std::stoi(argv[6]); - X = std::stoi(argv[7]); - Hi = std::stoi(argv[8]); - Wi = std::stoi(argv[9]); - window_stride_h = std::stoi(argv[10]); - window_stride_w = std::stoi(argv[11]); - in_left_pad_h = std::stoi(argv[12]); - in_left_pad_w = std::stoi(argv[13]); - in_right_pad_h = std::stoi(argv[14]); - in_right_pad_w = std::stoi(argv[15]); + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + window_dilation_h = std::stoi(argv[12]); + window_dilation_w = std::stoi(argv[13]); + in_left_pad_h = std::stoi(argv[14]); + in_left_pad_w = std::stoi(argv[15]); + in_right_pad_h = std::stoi(argv[16]); + in_right_pad_w = std::stoi(argv[17]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); exit(0); } @@ -107,6 +111,8 @@ int main(int argc, char* argv[]) Wi, window_stride_h, window_stride_w, + window_dilation_h, + window_dilation_w, in_left_pad_h, in_left_pad_w, in_right_pad_h, diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp index 34ff6f435ee0d9d5b80b7900ee065320e2a6d989..2621500ef1ee7ab8698e2323a614f7e595c1ed4b 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -34,18 +34,20 @@ int main(int argc, char* argv[]) bool time_kernel; // Pool shape - ck::index_t N = 128; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t window_stride_h = 2; - ck::index_t window_stride_w = 2; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; if(argc == 1) { @@ -59,31 +61,33 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = static_cast(std::stoi(argv[3])); } - else if(argc == 16) + else if(argc == 18) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = static_cast(std::stoi(argv[3])); - N = std::stoi(argv[4]); - C = std::stoi(argv[5]); - Y = std::stoi(argv[6]); - X = std::stoi(argv[7]); - Hi = std::stoi(argv[8]); - Wi = std::stoi(argv[9]); - window_stride_h = std::stoi(argv[10]); - window_stride_w = std::stoi(argv[11]); - in_left_pad_h = std::stoi(argv[12]); - in_left_pad_w = std::stoi(argv[13]); - in_right_pad_h = std::stoi(argv[14]); - in_right_pad_w = std::stoi(argv[15]); + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + window_dilation_h = std::stoi(argv[12]); + window_dilation_w = std::stoi(argv[13]); + in_left_pad_h = std::stoi(argv[14]); + in_left_pad_w = std::stoi(argv[15]); + in_right_pad_h = std::stoi(argv[16]); + in_right_pad_w = std::stoi(argv[17]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); exit(0); } @@ -107,6 +111,8 @@ int main(int argc, char* argv[]) Wi, window_stride_h, window_stride_w, + window_dilation_h, + window_dilation_w, in_left_pad_h, in_left_pad_w, in_right_pad_h, diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 584333e7bbac840af33d468848f923a59ad4c933..3b3ad80dd826fef38f8d4ed36d504538ae3e263a 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,8 +1,17 @@ +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) # dlops -add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) +if(DL_KERNELS) + add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) +endif() # xdlops -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) + set(target 1) + endif() +endforeach() endif() \ No newline at end of file diff --git a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp index 044f3c166a3f062760782167f3b7b4177f1693bd..2585072dfed88f569f983cc3e563c9deef8036be 100644 --- a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp index d5f4e6f62c3eb8e58285e81446a5e1332c37c533..aa3e0116954e603ed5cd759820ea49184bb4cb8a 100644 --- a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp index 2371737382447d666d55633fbf3ac976c3017cf7..4b207df5c628976c54d8b9c0f53ea140d51c98e5 100644 --- a/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 9df256c38d59f5ae48ca9539b5a3d637c9446c8f..2f880af3cfd5b8cd3cbeedd3eb1c43d60b8e47f8 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1,21 +1,30 @@ add_custom_target(example_grouped_gemm_xdl) -add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) -add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) -add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) -add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) -add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) -add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) - - -add_dependencies(example_grouped_gemm_xdl - example_grouped_gemm_xdl_fp32 - example_grouped_gemm_xdl_fp16 - example_grouped_gemm_xdl_bfp16 - example_grouped_gemm_xdl_int8 - example_grouped_gemm_multiple_d_dl_fp16 - example_grouped_gemm_xdl_splitk_fp16) - +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) + add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) + add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) + add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp) + add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp) + add_dependencies(example_grouped_gemm_xdl + example_grouped_gemm_xdl_fp16 + example_grouped_gemm_multiple_d_dl_fp16 + example_grouped_gemm_xdl_splitk_fp16 + example_grouped_gemm_xdl_fixed_nk_fp16 + example_grouped_gemm_xdl_fixed_nk_bias_fp16) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) +endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp index a5c51ceb0cfbc1a19215c8a2313b12a3b4d5ec51..3e1f7f089371b17fb02ebca53b27b0860c7b9c72 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp index 05d572a1f532446bdc0ec43e59b1eeadfe4bd054..680cee1f814db7cef177e87e4d9c03ba745b8e26 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a193fc39ba637cbd41df4743e0157ca40c38407b --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -0,0 +1,353 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; + +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, d0_tensors_device, + c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {0}}); + + grouped_gemm_kernel_args_.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ns.push_back(768); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ns[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89d4789c12a8ebe85a9344316205d06660a58984 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -0,0 +1,329 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + problem_size.Ms = { + 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ns.push_back(768); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 3f78dafa8977b1b57592500138dd8f034b294efd..90a12bc1ddbf27e04a55de23e3fc6845868dac43 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index fd93bb5f87d49a96da8973c3f862b8b48bdeb180..28b0fcd0cea6d6aed46ffeabc35c4c5ebc72b864 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp index faf41bbf0bba8e06fee9982a2f158706401ca45d..60c4a71a35146c31d9b79703faf5d3422942d817 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp index 7cb09778c521c9dfaed38c39ee5cb7fcff7bbc26..0c96ef56d3c144f19dcb4c7101a8d481d8960ee6 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp index a89937b2e2e95c6c8f03587e429bd72129825a48..743ab96be6291e617cb012957865a6897bd61a91 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 0f1ca777c5e1e8ffffd6e51505fa0d0dbd653070..00786d34a3af79a6d896e58b4fa12d9bbf8a7aa3 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,35 +1,38 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_gemm_reduce_xdl) add_custom_target(example_gemm_reduce_xdl_max) add_custom_target(example_gemm_reduce_xdl_mean_meansquare) add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - - add_dependencies(example_gemm_reduce_xdl_max - example_gemm_max_xdl_bf16 - example_gemm_max_xdl_fp16 - example_gemm_max_xdl_fp32 - example_gemm_max_xdl_int8) - - add_dependencies(example_gemm_reduce_xdl_mean_meansquare - example_gemm_mean_meansquare_xdl_fp16 - example_gemm_mean_meansquare_xdl_fp32 - example_gemm_mean_meansquare_xdl_bf16 - example_gemm_add_addsquare_xdl_int8) - - add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) + add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) + add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) + add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) + add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) + add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) + add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) + add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) + add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) + add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) + add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) + endif() + add_dependencies(example_gemm_reduce_xdl example_gemm_reduce_xdl_mean_meansquare example_gemm_reduce_xdl_max @@ -39,4 +42,6 @@ if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS M add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) endif() -endif() + set(target 1) + endif() +endforeach() diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index eb3832a668ff9829ee76e87e6ea8988b40e6980b..2f6533d4481a2f339417a19f0f290f7fd306ee9c 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp index e1248002f751abce96072ee64283ca00022a8713..b28e7f85d3137915a2fe438b16aef5447b2d4d24 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp index c2feffeb8953df030ad2e30cb942bec529b2cac1..b30ce2c48ad32d99dc5c8a4aaf7a3ee9ea774baf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp index 363390add3e99c96ef3891bec995cda5f8c3e24b..31e2efd6f635e30e569d97f73b4a5967b4909b82 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp index de6b7eb480b30d4347018da94ec518ff29fdbeec..d3c7c1d99c06a222bcefae42d94cdbb6fc8e1239 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp index 9666fc6622cdf3ab7cbab7071ebffd69d18b7668..9a4a6bc6e11a657b9d1eeaa35dd92552e397130d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp index 00e0b767a45dd979868b81635a79600d4695c779..1a8457a8bf87cdc82e687246014d2cba73065b8f 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp index 652c0e6ea6d2bbe3321de2deefb2c186d7252023..5c2706c79ace02e72f2956270c322c4b00782ebf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp index 7eee24fed83988698e0e6d46f830f3a9dfdb0257..c119e243702a27df9bba2d6cf3f791f523c98d7d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp index c250b996928dd406bde9253ccbf40e6cd9d77347..0f5e588383abf30064d27e452131ceef1a7ad828 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp index 62992de59765d3e05d935ac69894c4550a7ef472..1bea1bcf3e0fc1494e07952020db175a1d91bfbf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index ed95946c423f8813ff180596c80266728f0e012c..e187bd4337d53e9c1e7b2ecedf023169386825e3 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,6 +1,15 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) + set(target 1) + endif() +endforeach() + if(DL_KERNELS) + add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) + target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) + endif() endif() -add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) -target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp index 26fa9e9821fc050c9ab69fe1a170b4c72a98dcd3..4a9d16c5c303e43989c1f32e51c2cbce6f279e5d 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -81,32 +81,33 @@ int run_conv_bwd_data(bool do_verification, in_device_buf.SetZero(); // do GEMM - auto conv = DeviceConvNdBwdDataInstance{}; - auto invoker = conv.MakeInvoker(); - auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.N_, - conv_param.K_, - conv_param.C_, - conv_param.input_spatial_lengths_, - conv_param.filter_spatial_lengths_, - conv_param.GetOutputSpatialLengths(), - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - in_element_op, - wei_element_op, - out_element_op); - - if(!conv.IsSupportedArgument(argument)) + auto conv = DeviceConvNdBwdDataInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param.N_, + conv_param.K_, + conv_param.C_, + conv_param.input_spatial_lengths_, + conv_param.filter_spatial_lengths_, + conv_param.GetOutputSpatialLengths(), + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument.get())) { std::cout << "Not support,please check parameters or device"; return 0; } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = conv_param.GetFlops(); std::size_t num_btype = conv_param.GetByte(); diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp index f0896e977144867421b857afd0a58a239716712b..6b84eaba471a966c89b62d9f3be32d9b0e0834d6 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_bwd_data_common.hpp" diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp index c4f2c1f02bb9010ab3f5d97166819732e911b68a..c9989c60ac2c74e6821bdf6dd40e703998accf01 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_bwd_data_common.hpp" diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 0c3648dbfdd338b01dd6e77b093233f444303990..a1bb398af0640001f7fba2192e47ffc03e5ce3a0 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,4 +1,10 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() endif() - diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index c2e3602a7bb22e3f31c66bb6045171229584fa14..e363dc5c12dd84a36a67e969b06d29a179679f94 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp index bee5dea546f6da3cc23aa539583915c9b8a3b21e..24c8d82f674d97afae8cb1b3dc0274c1c13daf35 100644 --- a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp +++ b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp index 6fc63b899e5c6abc91cad26c92c7d7733f0e8f5a..3c04c561403d0ecfb0819aa569bdf24511053442 100644 --- a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp +++ b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp index a5a6bc0a8bedee0a16906381c91a9c93b070e572..1ac09641a1e20a6be8bf44f4f1ff0f3be045dedb 100644 --- a/example/19_binary_elementwise/elementwise_add_1d.cpp +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp index cc209b12e3df12413beabfd11961671f4e303b41..e571aa8468008a5f4e7097eee74763a583faf9af 100644 --- a/example/19_binary_elementwise/elementwise_add_4d.cpp +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index 0ee39ac84edf4e34562f96fcd7123e29777e611c..d649567ed2c46ce05073d731a435fbf65c034b73 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,15 +1,24 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_grouped_conv_bwd_weight) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) + add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) + endif() + set(target 1) + endif() +endforeach() - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16 - example_grouped_conv_bwd_weight_xdl_bf16) -endif() - -add_custom_target(example_grouped_conv_bwd_weight_dl) - -add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) - -add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + if(DL_KERNELS) + add_custom_target(example_grouped_conv_bwd_weight_dl) + add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) + add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) + endif() +endif() \ No newline at end of file diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index 3f4818d2e3336799c3dfb5cca4fb14d9428ab58a..15727495f0f1f8dfedcb24a0ffa02d3d7aea67e0 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index aed6d22b023aafc35e60fac29b6375a8602b87d5..31e277b5c709a44a3495348c2e9419dafeb7611c 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" using InDataType = BF16; // bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory @@ -17,8 +17,20 @@ using OutElementOp = PassThrough; template using DeviceConvBwdWeightInstance = - ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< - NDimSpatial, // NDimSpatial + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, InDataType, // InDataType WeiDataType, // WeiDataType OutDataType, // OutDataType diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp index 4a2a6195d9574c26e3526d570bfddff067bd3e37..69c831cc54eef84fc0b7ea46f9961573bf3efff4 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" using InDataType = F16; using WeiDataType = F16; @@ -16,8 +16,20 @@ using OutElementOp = PassThrough; template using DeviceConvBwdWeightInstance = - ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< - NDimSpatial, // NDimSpatial + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, InDataType, // InDataType WeiDataType, // WeiDataType OutDataType, // OutDataType diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index d46b37476697cec7121d26bdafdd4fb8bb4bb64e..29ce0324abecbbfbc53864871fa6cf558339bbc2 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight input_spatial_lengths{}; - std::array filter_spatial_lengths{}; - std::array output_spatial_lengths{}; + std::array input_lengths{}; + std::array input_strides{}; + std::array filter_lengths{}; + std::array weights_strides{}; + std::array output_lengths{}; + std::array output_strides{}; std::array conv_filter_strides{}; std::array conv_filter_dilations{}; std::array input_left_pads{}; @@ -82,9 +85,12 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; - range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); - range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); - range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); + range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths)); + range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); + range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths)); + range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides)); + range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths)); + range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.input_left_pads_, begin(input_left_pads)); @@ -96,13 +102,12 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.G_, - conv_param.N_, - conv_param.K_, - conv_param.C_, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index 7f974221b7fb3b6a4599d1aeb5d380a054293c3e..6a6735efd626bc1931fa9f7e6651dceae713da11 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,6 +1,13 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) + set(target 1) + endif() +endforeach() endif() diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp index 192fe87b626ff721ca106b9c633c109483f5772a..96d04dcb37798fb6db3334bd2d802e435fdc93fd 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index 3f01e6947728ecbef57734190b3896ff53899b02..fc58ca19f8673aa0b2205df21350f7ddf0b92196 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp index 4da6da65f7ab21b6ffae367747ccf5f0c71232dc..bd1d6932aceb72890dcb0f7c35004cc0dccd93af 100644 --- a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp index e7d857c4a0fa53a262243b759ddf9b33c691e26c..90d80f9f034b391f75c498f3a34232edf64f5260 100644 --- a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt index 156456115611ca12744e7773c199138733b88cd7..854f07fda6aad404fda546b40bef1c8b45c766d5 100644 --- a/example/22_cgemm/CMakeLists.txt +++ b/example/22_cgemm/CMakeLists.txt @@ -1,16 +1,21 @@ add_custom_target(example_cgemm_xdl) -add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) -add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) + add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) + add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) -add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) - -add_dependencies(example_cgemm_xdl - example_cgemm_xdl_bf16 - example_cgemm_xdl_fp16 - example_cgemm_xdl_fp32 - example_cgemm_xdl_int8) - +add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) + add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) +endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) diff --git a/example/22_cgemm/cgemm_xdl_bf16.cpp b/example/22_cgemm/cgemm_xdl_bf16.cpp index 92ed90ce4ab3ec09da2e053ceebe4d4ed6bcf36d..fa4482a984f20d203bd8bee68614c8e71ecbf83c 100644 --- a/example/22_cgemm/cgemm_xdl_bf16.cpp +++ b/example/22_cgemm/cgemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_common.hpp b/example/22_cgemm/cgemm_xdl_common.hpp index 6aa06b7c32cb476b61a88235e13a12a3d1a15db5..26137a7c2e50d0d5010639592c0a614e2eca607b 100644 --- a/example/22_cgemm/cgemm_xdl_common.hpp +++ b/example/22_cgemm/cgemm_xdl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp index 11373736ee8b37efb5f4082253d46f02074f011f..89a581e865a56b231711352f9733403ed2945aea 100644 --- a/example/22_cgemm/cgemm_xdl_fp16.cpp +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_fp32.cpp b/example/22_cgemm/cgemm_xdl_fp32.cpp index 0f45c18c4818726f179adda70dc12b5ea7c45b9d..cf9659959990823f10ef0cb30fa0eef928c32adc 100644 --- a/example/22_cgemm/cgemm_xdl_fp32.cpp +++ b/example/22_cgemm/cgemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_int4.cpp b/example/22_cgemm/cgemm_xdl_int4.cpp index c26a83baafd375b477845fbdd793c3918d7b7dc5..f69cc2b3cc4f98e94ddbf0b37e8ac7324515a8d6 100644 --- a/example/22_cgemm/cgemm_xdl_int4.cpp +++ b/example/22_cgemm/cgemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_int8.cpp b/example/22_cgemm/cgemm_xdl_int8.cpp index 2f24189861d8468848d28b9c0480da7c3f2c4fdd..c4835b853ee75496f963a8e14d7a0190eb4cb1cf 100644 --- a/example/22_cgemm/cgemm_xdl_int8.cpp +++ b/example/22_cgemm/cgemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/23_softmax/softmax_blockwise.cpp b/example/23_softmax/softmax_blockwise.cpp index 41afd72f5ac9ddc3dc030a914e1a26da60c27356..d09e434bcfbb262a83c5167429d41d1ca54391ef 100644 --- a/example/23_softmax/softmax_blockwise.cpp +++ b/example/23_softmax/softmax_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 7962576e875dd6b7addd659dd9b9f3b6c1504a41..48a3b58ff53a712492e82ffce5632c08b6937ba2 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -1,16 +1,20 @@ add_custom_target(example_batched_gemm_xdl) - -add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) -add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) -add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) - -add_dependencies(example_batched_gemm_xdl - example_batched_gemm_xdl_fp32 - example_batched_gemm_xdl_fp16 - example_batched_gemm_xdl_bfp16 - example_batched_gemm_xdl_int8) - +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bfp16) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) +endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) diff --git a/example/25_gemm_bias_e_permute/CMakeLists.txt b/example/25_gemm_bias_e_permute/CMakeLists.txt index cbc3c007bc22622554fd93f4d8a829c9ab666dc4..eb274b233802401628abe0f030159da75b8f5950 100644 --- a/example/25_gemm_bias_e_permute/CMakeLists.txt +++ b/example/25_gemm_bias_e_permute/CMakeLists.txt @@ -1,2 +1,4 @@ -add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) -add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) + add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) +endif() diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp index c934d35019602e634da6f8b6d49c1c3b133c6800..420a7cf74f3186ac62d5dc37346a202178d8d273 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp index 98835f98fa6ecea9c38f1155f5dd03d0787fa0d2..9d606db205dde86d544971b4b7fc4830ba73c568 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index c58751f0dd38e8fcc0df0da428777e7f3c7645b1..6cab88b13ef5761a16cedcd9d86d7f855fe5b99e 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -1,5 +1,8 @@ -add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) -add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) - -add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) -add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) + add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) +endif() +if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) + add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) +endif() diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp index 6004db6e0e1b0c28ea195066dbefcfd4253fadff..78522160c85a1024ec510197e5ab1069e6f077cf 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp index 9576ce3f2c3d47ccf576c191e8eca1b0bd97ccd5..6cceed5bc11ed8eb79568a0397bbf7e1b93fd33d 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp index 3aa2a7ba93858fb950eb785bac740cdfafe2b52a..1574f5d18fb7f5f193cd4adb1c1a88616d7b1437 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/26_contraction/contraction_scale_xdl_fp64.cpp b/example/26_contraction/contraction_scale_xdl_fp64.cpp index cccf6505c6473567c0ee2f096e3a038e23385bca..3dacc708877d838dd5fb0113b5d560885b0551ae 100644 --- a/example/26_contraction/contraction_scale_xdl_fp64.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/27_layernorm/CMakeLists.txt b/example/27_layernorm/CMakeLists.txt index 94c23ce77499e0c110301a3d3ae82b0ec119a7c7..9cb2cd0766ea8c5aaa9154e87457763cc38372ab 100644 --- a/example/27_layernorm/CMakeLists.txt +++ b/example/27_layernorm/CMakeLists.txt @@ -1,2 +1,4 @@ -add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) -add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) + add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) +endif() diff --git a/example/27_layernorm/common.hpp b/example/27_layernorm/common.hpp index 8d833a3ae9e765d96ba03459aa8b1ba18e4823ba..62a71713df84351fea902cdcf3275786b60f5a23 100644 --- a/example/27_layernorm/common.hpp +++ b/example/27_layernorm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/27_layernorm/layernorm_fp16.cpp b/example/27_layernorm/layernorm_fp16.cpp index c15ffabf509415a9bb0a702997819a7da3bb8fb4..bb8b954f0acaf2e9d748104a05561f9cb02fc9a0 100644 --- a/example/27_layernorm/layernorm_fp16.cpp +++ b/example/27_layernorm/layernorm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/27_layernorm/layernorm_splitk_fp16.cpp b/example/27_layernorm/layernorm_splitk_fp16.cpp index 01ee7161eb05164b37b914d936488cca0145f851..e0378d028b343a6a70335cb9163370d1586ebf0a 100644 --- a/example/27_layernorm/layernorm_splitk_fp16.cpp +++ b/example/27_layernorm/layernorm_splitk_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/27_layernorm/run_layernorm_example.inc b/example/27_layernorm/run_layernorm_example.inc index 678d8df281684f19f92b892766292358e7c00310..95200b540aa9f9704c8fad785b352bdf779c4094 100644 --- a/example/27_layernorm/run_layernorm_example.inc +++ b/example/27_layernorm/run_layernorm_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt index 44ab16894ce054a4b25b2efd12ebac9152b3bd58..2fda1f62a9a94c69bde8f08e8c81167b2951c4d6 100644 --- a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt +++ b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt @@ -1 +1,3 @@ -add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) +endif() diff --git a/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp b/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp index f8e6501eadaee4e5b4156309afebdf16e11f3e8a..24e9b1d9b7d569adf20caf1bd0f0cddafdbd256b 100644 --- a/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index 32a87dd200fc5f9048fc2eb1f56f343fe5586c64..09c3e6c6085c233aeac87374e6996cade4dac142 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,5 +1,7 @@ -add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) + if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") + add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) + endif() endif() diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 30ad38a566309dcfacd52f100099d8a3bf02f77f..62233e535151e6b647e938de04b6225fe487b6c4 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp index 25d815b9cdfdddc29c699a1bc51dd5191a5246eb..08158bfc250642e032d20fad940bb8853c71d446 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 9780a64cc28ce8985e9a83cdf6c55bac87f592fc..e37413c09880de71b5bc10b15624fb98e75bf4ba 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -1,24 +1,46 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_custom_target(example_grouped_conv_fwd_multiple_d) - - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) - - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) +list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + add_custom_target(example_grouped_conv_fwd_multiple_d) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) + add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) + endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) endif() # USE_BITINT_EXTENSION_INT4 - add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) -endif() -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) -endif() + + set(target 1) + endif() +endforeach() + +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) + endif() + set(target 1) + endif() +endforeach() diff --git a/example/30_grouped_conv_fwd_multiple_d/common.hpp b/example/30_grouped_conv_fwd_multiple_d/common.hpp index e7c6ed9b939abb0e8f593d54d5b45c499db5f2c0..e60ebee6e4fe2e08db432816b75bbbd4c388946c 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp index eb6975a6d812aef66f1caf493dfe82549a0e4ae9..ae769ff1d38306f4feed8d3cb00dde79b3f561eb 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp index 9d1d257a2889dcdaefb697faa6bcf6c172c55839..039d25029921491e7e67808e554b8cb3e6eb4745 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common_wmma.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp index ee300d073a28b44cf8177f64a7278d5e13ef3130..43c0d57dc2adcae69ecde327137eb204b71fa44d 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp index 5a9df0b1e880c8f3bc954c07cd5de3639644c4e1..40b4132b358d8259146c2987472dc7e9d6c71091 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp index c2906cc9dd1dbe22c8d420ce0929783fb0be0fa6..e05d384f26b40273a6bfe77222b805289a99e631 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp index 3d5a243e6b9282f9d362c82c0758776341df646f..5494563fdd568d1f51ec9ee9042d094096e8fc8a 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp index eaf680fa438a14ee402e3805c18692ffdcf78c7e..6bf2e8d963c1524bfec0680409580b47e5394bc5 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index 4561156e0bd3a935c49751802e8c3037d8735228..eb242203eaa65ba7b8fca45c5f1e3e1b9d96c409 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template struct LayoutSetting diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index a6888649c0ed0adc6f2084c428b2c847cd952f77..360b2c8947b371b1502afb3af73c469f0d8c3636 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template struct LayoutSetting diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc index d087c31af5def501635ef173769851e111dec3cb..58ed69182e0622751dc3fa2823fcb7cfa5b279a8 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template using DeviceConvFwdInstance = diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index dd9aef94a08b921cf6d60b3f6938281203a64f41..2074520f8c06d527be4ff1ea6a87c388cfbd6391 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,12 +1,26 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) - if(NOT GPU_TARGETS MATCHES "gfx940") - add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) +list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list2 gfx908 gfx90a) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) endif() - if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) -endif() \ No newline at end of file + set(target 1) + endif() +endforeach() + +if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) + endif() +endif() diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp index 74e0e07e62a99e0700213e80a066d7e15128e516..7605d9c4f8368e369e4c29167f90620984b5a6c8 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp index d5fadb8081e86a83784f6fad7e7405a7808b65b9..33ed04fb3068b1bf83eb128825ccad355c6de4f9 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp index 0dd4e0914f479eefae604ceba8115432ce1d3803..e0eb193ad0484c62ac6fa5695ac43aa17c509227 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp index 1fd93622a1b14a72d8159517a84b9f19d2959723..d166214c3376cd90afd11796a4cb85ea28421861 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp index 15d98abab7dfa0f1168fd411eeae481054053000..40f87d1f554c4ef2a73ceb32263c3499f733b40a 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc index 7e5f1614bcf6c6133d88f9d8dad63ec4be47c9ac..f329146728dd8df3c38af3d2520ac91d43d12d00 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 8d9aaec85a503b060254ef748db77f4a7f289dae..0463bf6bd31483ada687f17ff175e0195404623c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,16 +1,24 @@ -add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) -add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) -add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) -add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) + add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) + add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) + add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) + add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) + add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) +endif() add_custom_target(example_gemm_scale_softmax_gemm) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) -add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) + add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) + add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) +endif() diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index 0eb15653306f0179b1d132ddfcd130af6ba13fd6..1d1566d57561afaa84be2de34b11416b5da0571c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp index 8f1db577c604418b7b1d546b5de251f5a20054a6..bae88d4b8e6f786219e8d68bd0a34f6036dd0d72 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp index 2ce91a8c6023314d61b354614c81ca74c5f95727..a098ce6675e05308fd9de223e807a0e14504a7f5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp index 1fd2bf69306f5b35dd4ea2ae484c677c6b3c63fc..ce8caf758842225098baac92876ea6996321ff21 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp index f4a8589052f0f149902553884a58cc4c6c79d030..138db14963809cd49294d1cfc2171b038dd10c60 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index e4a71b04313446f75f4cbc73456a66bdbb62c47e..57949242941d7f750d1b121712f1d5cca6cee4d5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp index 38b5badc6e4f270e409ed98c492035c5bebe00c3..97caec6053bb7846dd53f37b26c3945a55812967 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc index 4e43dbdd8fc5a83486cd052ee3e789e05152861f..27602e2313f7aa197e88e1fabeb39245e2fdf5eb 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc index 0b876af952f65f89e9ae55bea344bfa6f61afd5d..fa76faea84e4551ddf8d0617c132dbe7a6045fb3 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index ef2acf61f55fa7f108e174d4586c13ae3dc7419d..ea1e2734a684b61a363eb93ef0e2ff933f900ea5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { diff --git a/example/33_multiple_reduce/dual_reduce_common.hpp b/example/33_multiple_reduce/dual_reduce_common.hpp index 326606752b25394d504cb054cf3026173c66e970..cd21790be6548abd0f8097613ee650d4407fc400 100644 --- a/example/33_multiple_reduce/dual_reduce_common.hpp +++ b/example/33_multiple_reduce/dual_reduce_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/33_multiple_reduce/dual_reduce_multiblock.cpp b/example/33_multiple_reduce/dual_reduce_multiblock.cpp index 9360599ed9e8a7517bc244e2942d6949c1f71e61..198931749b19b8e8d91bf16c1a733801cb27aeba 100644 --- a/example/33_multiple_reduce/dual_reduce_multiblock.cpp +++ b/example/33_multiple_reduce/dual_reduce_multiblock.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/33_multiple_reduce/dual_reduce_threadwise.cpp b/example/33_multiple_reduce/dual_reduce_threadwise.cpp index 56255839e567fa0386ba09b80738f5d0d2abcc23..7609edad3527e8219c2ef6351381d6e53455a67b 100644 --- a/example/33_multiple_reduce/dual_reduce_threadwise.cpp +++ b/example/33_multiple_reduce/dual_reduce_threadwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/CMakeLists.txt b/example/34_batchnorm/CMakeLists.txt index d964f40d877ddb5dc6e03cf8d28897e244514d96..60824c5f4d7d742941721085be6fc22d638dffe9 100644 --- a/example/34_batchnorm/CMakeLists.txt +++ b/example/34_batchnorm/CMakeLists.txt @@ -1,3 +1,4 @@ add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp) +add_example_executable(example_batchnorm_forward_training_obsolete batchnorm_forward_training_nhwc_obsolete.cpp) add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp) add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp) diff --git a/example/34_batchnorm/batchnorm_backward_nhwc.cpp b/example/34_batchnorm/batchnorm_backward_nhwc.cpp index a6ca9d150bd918966bae06a63ee4eb9da6be5f3f..3756310fd7dd47b111b9fde743cac844ae2c7669 100644 --- a/example/34_batchnorm/batchnorm_backward_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_backward_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/batchnorm_common.hpp b/example/34_batchnorm/batchnorm_common.hpp index bdbc8ea8b88f40de2f25a2ec8c5a74ab5e38fd74..a1b8d253bf061c02511945b98eac9343c0d1a236 100644 --- a/example/34_batchnorm/batchnorm_common.hpp +++ b/example/34_batchnorm/batchnorm_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp index dc2984851a02e0020c8c4cfa6d59e61bfe49c593..6a8002025a60b277c587190a95855cc262c33e17 100644 --- a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp index da36d65a2954ee1ac3a94c617a219e6a7c2baf44..b27358fd9de4547dc8fa8532981f097713eb1e07 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification, (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); y_dev.FromDevice(y.mData.data()); - pass = pass && ck::utils::check_err(y, y_ref); + pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values"); if(updateMovingAverage) { @@ -424,8 +424,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification, resultRunningMean_dev.FromDevice(resultRunningMean.mData.data()); resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data()); - pass = pass && ck::utils::check_err(resultRunningMean, resultRunningMean_ref); - pass = pass && ck::utils::check_err(resultRunningVariance, resultRunningVariance_ref); + pass = pass && ck::utils::check_err(resultRunningMean, + resultRunningMean_ref, + "Incorrect running mean values"); + pass = pass && ck::utils::check_err(resultRunningVariance, + resultRunningVariance_ref, + "Incorrect running variance values"); }; if(saveMeanAndInvVariance) @@ -438,8 +442,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, resultSaveMean_dev.FromDevice(resultSaveMean.mData.data()); resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data()); - pass = pass && ck::utils::check_err(resultSaveMean, resultSaveMean_ref); - pass = pass && ck::utils::check_err(resultSaveInvVariance, resultSaveInvVariance_ref); + pass = pass && ck::utils::check_err( + resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values"); + pass = pass && ck::utils::check_err(resultSaveInvVariance, + resultSaveInvVariance_ref, + "Incorrect saved invvariance values"); }; }; diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffb9f4b58468699c1a34ff6aff101c2dd4adaf66 --- /dev/null +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp @@ -0,0 +1,598 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class BatchNormFwdArg +{ + private: + int option_index = 0; + + public: + std::vector inOutLengths; + + bool do_verification = false; + + bool updateMovingAverage; + bool saveMeanAndInvVariance; + + int data_type = 0; + int init_method = 2; + bool time_kernel = false; + bool use_multiblock_welford = false; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension " + "lengths, must have 4 integers for nhwc" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the batch-normalization " + "result by " + "comparing with the host-based batch-normalization" + << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance " + "(0=no, 1=yes)" + << std::endl; + std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance " + "(0=no, 1=yes)" + << std::endl; + std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer " + "value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl; + std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inOutLengths = getTypeValuesFromString(optarg); + + if(inOutLengths.size() != 4) + throw std::runtime_error( + "NHWC tensor layout should have 4 length values specified!"); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 6 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + updateMovingAverage = std::atoi(argv[optind++]); + saveMeanAndInvVariance = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind++])); + use_multiblock_welford = static_cast(std::atoi(argv[optind])); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return (-1); + + return (0); + }; +}; + +using namespace ck; + +template +bool bnorm_fwd_nhwc_test(bool do_verification, + int init_method, + bool time_kernel, + const std::vector inOutLengths, + bool updateMovingAverage, + bool saveMeanAndInvVariance, + double averageFactor, + double epsilon) +{ + // for NHWC BatchNorm calculation of mean and meansquare + constexpr int Rank = 4; + constexpr int NumReduceDim = 3; + + // when using lengths[] to create a tensor, lengths[0] is the length of highest dimension + // eg. N of NHWC, so lengths[3] is the dimension C length of NHWC + const std::vector scaleBiasMeanVarLengths = {inOutLengths[3]}; + + // input data of the batchnorm forward algorithm + Tensor x(inOutLengths); + Tensor bnScale(scaleBiasMeanVarLengths); + Tensor bnBias(scaleBiasMeanVarLengths); + + // output data of the batchnorm forward algorithm + Tensor y_ref(inOutLengths); + Tensor y(inOutLengths); + + Tensor resultSaveMean_ref(scaleBiasMeanVarLengths); + Tensor resultSaveInvVariance_ref(scaleBiasMeanVarLengths); + + Tensor resultRunningMean_ref(scaleBiasMeanVarLengths); + Tensor resultRunningVariance_ref(scaleBiasMeanVarLengths); + + auto inOutStrides = x.mDesc.GetStrides(); + auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides(); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(updateMovingAverage) + { + if constexpr(std::is_same::value) + { + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + + const float x_mean = 0.0f; + const float x_stddev = 2.5f; + const float noise_stddev = 0.04f; + + resultRunningMean_ref.GenerateTensorValue( + GeneratorTensor_4{x_mean, noise_stddev}, num_thread); + + resultRunningVariance_ref.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + } + else + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + const float noise_stddev = 0.04f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + + // initialize the runningMean to be values with tiny variation to the mean of the x + // values + resultRunningMean_ref.GenerateTensorValue( + GeneratorTensor_4{x_mean, noise_stddev}, num_thread); + + // initialize the runningVariance to be values with tiny variation to the variance of + // the x values + resultRunningVariance_ref.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + }; + } + else + { + if constexpr(std::is_same::value) + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + else + x.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + }; + + if(do_verification) + { + switch(init_method) + { + case 0: + bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + break; + case 1: + bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_1{0}, num_thread); + break; + case 2: + bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + bnScale.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + } + }; + + // these buffers are usually provided by the user application + DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize()); + DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize()); + + // mean_dev or resultSaveMean_dev + DeviceMem resultSaveMean_dev(sizeof(AccDataType) * + resultSaveMean_ref.mDesc.GetElementSpaceSize()); + // meansquare_dev or resultSaveInvVariance_dev + DeviceMem resultSaveInvVariance_dev(sizeof(AccDataType) * + resultSaveInvVariance_ref.mDesc.GetElementSpaceSize()); + // resultRunningMean_dev + DeviceMem resultRunningMean_dev(sizeof(AccDataType) * + resultRunningMean_ref.mDesc.GetElementSpaceSize()); + // resultRunningVariance_dev + DeviceMem resultRunningVariance_dev(sizeof(AccDataType) * + resultRunningVariance_ref.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + bnScale_dev.ToDevice(bnScale.mData.data()); + bnBias_dev.ToDevice(bnBias.mData.data()); + + if(updateMovingAverage) + { + resultRunningMean_dev.ToDevice(resultRunningMean_ref.mData.data()); + resultRunningVariance_dev.ToDevice(resultRunningVariance_ref.mData.data()); + }; + + std::array i_inOutLengths; + std::array i_inOutStrides; + std::array i_scaleBiasMeanVarLengths; + std::array i_scaleBiasMeanVarStrides; + + ck::ranges::copy(inOutLengths, i_inOutLengths.begin()); + ck::ranges::copy(inOutStrides, i_inOutStrides.begin()); + ck::ranges::copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin()); + ck::ranges::copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin()); + + using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; + + using DeviceBatchNormFwdInstance = + ck::tensor_operation::device::DeviceBatchNormFwdImpl; + + auto batchnorm_fwd = DeviceBatchNormFwdInstance{}; + + auto argument_ptr = batchnorm_fwd.MakeArgumentPointer( + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[] + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x_dev.GetDeviceBuffer(), + bnScale_dev.GetDeviceBuffer(), + bnBias_dev.GetDeviceBuffer(), + epsilon, + PassThroughOp{}, + y_dev.GetDeviceBuffer(), + saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, + saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, + averageFactor, + updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr, + updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr); + + if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, " + "exiting!" + << std::endl; + return (false); + }; + + size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get()); + + DeviceMem workspace_dev(workspace_sz); + + batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer(); + + if(time_kernel) + { + float avg_time = 0.0f; + size_t num_bytes = 0; + + size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3]; + size_t invariant_length = inOutLengths[3]; + + avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // inputing of x, scale, bias, outputing of y + num_bytes += + total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2; + + // outputing of mean, inv-variance + num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0; + + // updating of moving mean, variance + num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + } + else + (void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + + if(do_verification) + { + + using ReferenceBatchNormFwdInstance = + ck::tensor_operation::host::ReferenceBatchNormFwd; + + auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{}; + + auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer( + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[] + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x.mData.data(), + bnScale.mData.data(), + bnBias.mData.data(), + epsilon, + PassThroughOp{}, + y_ref.mData.data(), + saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, + saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr, + averageFactor, + updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, + updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr); + + if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get())) + { + std::cout << "The runtime parameters seems not supported by the BatchNorm reference " + "instance, exiting!" + << std::endl; + return (false); + }; + + auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer(); + + (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); + + y_dev.FromDevice(y.mData.data()); + pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values"); + + if(updateMovingAverage) + { + Tensor resultRunningMean(scaleBiasMeanVarLengths); + Tensor resultRunningVariance(scaleBiasMeanVarLengths); + + resultRunningMean_dev.FromDevice(resultRunningMean.mData.data()); + resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data()); + + pass = pass && ck::utils::check_err(resultRunningMean, + resultRunningMean_ref, + "Incorrect running mean values"); + pass = pass && ck::utils::check_err(resultRunningVariance, + resultRunningVariance_ref, + "Incorrect running variance values"); + }; + + if(saveMeanAndInvVariance) + { + using ck::host_common::dumpBufferToFile; + + Tensor resultSaveMean(scaleBiasMeanVarLengths); + Tensor resultSaveInvVariance(scaleBiasMeanVarLengths); + + resultSaveMean_dev.FromDevice(resultSaveMean.mData.data()); + resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data()); + + pass = pass && ck::utils::check_err( + resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values"); + pass = pass && ck::utils::check_err(resultSaveInvVariance, + resultSaveInvVariance_ref, + "Incorrect saved invvariance values"); + }; + }; + + return (pass); +}; + +const double epsilon = std::numeric_limits::epsilon(); +static const double averageFactor = 0.1; + +int main(int argc, char* argv[]) +{ + bool pass = true; + + if(argc > 1) + { + BatchNormFwdArg arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + if(arg.data_type == 0) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 1) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 3) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 5) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 6) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + } + else + { + pass = bnorm_fwd_nhwc_test(true, + 2, + false, // don't time kernel + {128, 16, 6, 512}, + true, + true, + averageFactor, + epsilon); + + pass = pass && bnorm_fwd_nhwc_test(true, + 2, + false, // don't time kernel + {128, 16, 3, 1024}, + true, + true, + averageFactor, + epsilon); + }; + + return (pass ? 0 : 1); +} diff --git a/example/34_batchnorm/batchnorm_infer_impl.hpp b/example/34_batchnorm/batchnorm_infer_impl.hpp index 15170586b636825325c796cdc20eb70996a24bb7..d0b545b2a31d2fc095eece12ce273db823a5d2ac 100644 --- a/example/34_batchnorm/batchnorm_infer_impl.hpp +++ b/example/34_batchnorm/batchnorm_infer_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index f5a6ccb2414d9b7c3ca4753c93018a280a243f08..251a9b93c57e3b24ae9a1305bbe7607eb696817f 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,18 +1,28 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_splitK_gemm_xdl) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) - add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - - add_dependencies(example_splitK_gemm_xdl - example_splitK_gemm_xdl_fp32 - example_splitK_gemm_xdl_fp16 - example_splitK_gemm_xdl_bfp16 - example_splitK_gemm_xdl_int8) - + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bfp16) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) + endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) endif() -endif() + set(target 1) + endif() +endforeach() diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp index 7191ecf50ab5252b37855072b4e03f1194df52f6..fdf49a31b719c2aa0e7d6441e52c4dd1475d4f95 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -33,6 +33,7 @@ using ADataType = BF16; using BDataType = BF16; using AccDataType = F32; using CDataType = F32; +using ComputeType = BF16; using ALayout = Row; using BLayout = Col; @@ -46,11 +47,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle // clang-format off -//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>; +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; // clang-format on #include "run_splitK_gemm_example.inc" diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp index efdb315b4e5576aef8d992f6d57c231ac1f83a81..74fb16e15b019e145cdf13549254be1b09fbd8ca 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp index bc2e3d1d52b668a2ea15c9f4ea1db44c4faeb19e..7506f694204b3c3aaa564704eda4f358baf17861 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp index 4eb27824628d8f193ca6108a589794041fc2b2fb..7ebf9144082de53d64f3662e579dd6d246266f1d 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp index eefdbca6b1ae3e49cc221b3037e46888b4a9a9bd..6b0c1aa02d05d848d8280510fa7d05de8b7d4287 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -30,6 +30,7 @@ using ADataType = int8_t; using BDataType = int8_t; using AccDataType = int32_t; using CDataType = int32_t; +using ComputeType = int8_t; using ALayout = Row; using BLayout = Col; @@ -43,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle // clang-format off -//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4>; +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; // clang-format on #include "run_splitK_gemm_example.inc" diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index f0a0cdf6f13df7081d15ab0099122542ccf6118e..d2337dcda5d30eb57e0e5f764f047cf21314a190 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt index a9be3a7108fee6075dd7ae81f7536b072e655c0b..36bb5720d53c1a738323f50d0e855d50d8996047 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt +++ b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt @@ -1 +1,3 @@ -add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) +endif() diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index 071e8a7431c057a32ea0061655d81a202774b923..36dcf58d7044b3bf54d3e318e5d095a40d70255d 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o] @@ -173,6 +173,8 @@ using DeviceGemmInstance = 8, 8, true, + 9, // D0sTransferSrcVectorDim + 4, // D0sTransferSrcScalaerPerVector S<8, 32, 1>, // B1BlockTransfer S<0, 2, 1>, S<0, 2, 1>, @@ -189,7 +191,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = false; + bool time_kernel = true; // GEMM shape ck::index_t M = 1024; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 472d59c77d64ac52f1a6f8b562312917487328ed..3821f8aaca74a3b028fc3917c10e5f279375c3fe 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,8 +1,15 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_grouped_conv_bwd_data) add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp) add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) -endif() \ No newline at end of file + set(target 1) + endif() +endforeach() +endif() diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index d07ee7bdc1c4b999f61a84727bb10fe7ea6f4c15..ca824b1075ced161760892fe6de53767b1079ccc 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp index 55ea8c3a3109469532d0e5c0566da3eb11e7353d..a3533bb4cc4b4874f0e7ae9f333347913d981ba0 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp index ddf82ec512c61c72a1a2e2323a6645f39f15ce6d..fb688b6f3f15383c2920a12e8902c2d4cf05a109 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc index 0afd8bd70da849085db6e19583e933c3d1b91968..0f0b120cbcc578b2fa7f650c37c9a28148dd834a 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. bool run_conv_bwd_data_bias_relu(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params, diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc index e50c98bbe844f6ec438c781b7d5d8e254e57d98f..25678491ce90cdf1ab9455f14e091bd5fcd67431 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. bool run_conv_bwd_data(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params, diff --git a/example/39_permute/CMakeLists.txt b/example/39_permute/CMakeLists.txt index 573ad7239e608e7ba947f9ca864a0c7f4da4689b..5b43de9725aa3d535851926884fc6790233c6dbb 100644 --- a/example/39_permute/CMakeLists.txt +++ b/example/39_permute/CMakeLists.txt @@ -1,9 +1,11 @@ -add_custom_target(example_permute) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_custom_target(example_permute) -add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) -add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) -add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) + add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) + add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) + add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) -add_dependencies(example_permute example_permute_1xHxW_fp16) -add_dependencies(example_permute example_permute_NxHxW_fp16) -add_dependencies(example_permute example_permute_HxWx4_fp16) + add_dependencies(example_permute example_permute_1xHxW_fp16) + add_dependencies(example_permute example_permute_NxHxW_fp16) + add_dependencies(example_permute example_permute_HxWx4_fp16) +endif() diff --git a/example/39_permute/common.hpp b/example/39_permute/common.hpp index ab612cea1794c422c737a02e5bbe7a6728904c40..54f3a788097bebc6e293389cecd03d8ae68d8494 100644 --- a/example/39_permute/common.hpp +++ b/example/39_permute/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/39_permute/permute_1xHxW_fp16.cpp b/example/39_permute/permute_1xHxW_fp16.cpp index d7f9b80544a452ca7ea062d7d6135b92a2bf19ea..7336c3b631bcadd7d23a2d6b8e0b3d3a31787516 100644 --- a/example/39_permute/permute_1xHxW_fp16.cpp +++ b/example/39_permute/permute_1xHxW_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/39_permute/permute_HxWx4_fp16.cpp b/example/39_permute/permute_HxWx4_fp16.cpp index 342aa134ec5570c84447b137e5be165c9a5f697c..6c24919ded6a0c65556e305455861ca137203b3e 100644 --- a/example/39_permute/permute_HxWx4_fp16.cpp +++ b/example/39_permute/permute_HxWx4_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/39_permute/permute_NxHxW_fp16.cpp b/example/39_permute/permute_NxHxW_fp16.cpp index b53975eb2c8632583f61afde50a03eba32c17c9b..3551d2a7c8decf31a48ca628d99a53639f4fa1fe 100644 --- a/example/39_permute/permute_NxHxW_fp16.cpp +++ b/example/39_permute/permute_NxHxW_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/39_permute/run_permute_bundle_example.inc b/example/39_permute/run_permute_bundle_example.inc index 70406d63f91b95f4d1cce025051a74a0ee3d114e..2c198729226243fb1fb97db6beefa1fd1f62f625 100644 --- a/example/39_permute/run_permute_bundle_example.inc +++ b/example/39_permute/run_permute_bundle_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/39_permute/run_permute_element_example.inc b/example/39_permute/run_permute_element_example.inc index bc6235303039c7cb33a4057446137a9f18ddb186..35871344567ddf62a4fb3c2bd6a2e77cb94a1998 100644 --- a/example/39_permute/run_permute_element_example.inc +++ b/example/39_permute/run_permute_element_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index c12ab7a3407803849ff095ea3d6da28002e93a1f..55464957ac62e831e332e6b88b97e283e22866c8 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,23 +1,28 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) -endif() -# Conv perlayer quantization -add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) + set(target 1) + endif() +endforeach() -# Conv perchannel quantization -add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) - -# Conv + bias + relu perlayer quantization -add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) - -# Conv + bias + relu perchannel quantization -add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) - -# Conv + bias + tanh perlayer quantization -add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) - -# Conv + bias + tanh perchannel quantization -add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) \ No newline at end of file + if(DL_KERNELS) + # Conv perlayer quantization + add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) + # Conv perchannel quantization + add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) + # Conv + bias + relu perlayer quantization + add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) + # Conv + bias + relu perchannel quantization + add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) + # Conv + bias + tanh perlayer quantization + add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) + # Conv + bias + tanh perchannel quantization + add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) + endif() +endif() \ No newline at end of file diff --git a/example/40_conv2d_fwd_quantization/common.hpp b/example/40_conv2d_fwd_quantization/common.hpp index 6ee14d750ef6c9f8308ff7b17b86626fc70ca2b8..266b09145c84c456aedb2182554ced7815457410 100644 --- a/example/40_conv2d_fwd_quantization/common.hpp +++ b/example/40_conv2d_fwd_quantization/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp index 5c445d9c50b6ea3235f045a7e569aaf87fb048ea..4573c68658bf44dc49ccade302feb1547746f0f9 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp index 0ff85f008fa6fd7101436ec55c02c1e1fbae79cf..005f6263fd46e8b505fbc6859a5fa9ed334e7229 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp index f8f996d17e82948076cbb0fa52ff3c378f7cd7cb..62e5e583de864524d2e937683a8c60cdcf9d57da 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp index 3b25fec0c4a1bcbf89aacd2c1c506358a09000be..ef98fe7e4f945bab663b7e5ee895633186919df9 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp index a98a1e240bcbf390cfaed0ec9230f10d88ed90ce..e524ddb2b297ae218eaf197bea1f149e01a76efc 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp index 262594d58b3d3e591ab38209d86c3b683bd7680e..d29a3143c0ed4196a64ade05af0dc295a358c527 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp index 6b22055053d3f45ce6f22e03bff25e5b5b7484fa..06c839e4e226e082e9754f12ce39cececbd36269 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp index 1ac8679743777db2c5d99370b8b22ea8004bc8d6..7a9b42d39f971ab1c36476f4c0ead172cb6d2606 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp index f28abe5ebc92d179fd65367b028e76bd4de0cbe4..3495636297d5d83fb3ccc7d236e6135de9b715fd 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp index f468e8adcde28b7fa0099405b76dca7b9d29ce64..2611337254212b326f6eef540c6a4eca5cd19261 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc index 5675db77fae1a88a7de156a75c86f50e76f1fde0..e5b924ad5114a6e31a1ba2875118085329d5e8ce 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once template #include diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp index b7f80e76d6ce5a690beef4f48361e0a65402cb47..496e676a402f4d12a325c2afef56ec422cb4d721 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp index 15e460948ef528b697a4e3bad7d40ca332f85c15..35d50721dcea27b2bfb42eea6e39da08709dfc7d 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp index 2cc4c07c0d89fc83b233b50967cd595dc0e53ac7..80f6e9ae05712b1df4b05b1bb75fbff62d9ca70a 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp index 40ff0f69cc14cd644361f3391eea8ed52380adc0..3ade6c811ae5583bbb8a2a3b277c6f4e7f265a87 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc b/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc index a2c97f4d421f9382aa279b257a49299c4016e2c3..0722d497d8df0ebf0acf0cf798215bebcdee6e3a 100644 --- a/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc +++ b/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/42_groupnorm/CMakeLists.txt b/example/42_groupnorm/CMakeLists.txt index e8c306ac582f56607c8bfa309c4a3f16cd64be50..bc2246a2bfd39ad933f42b5b5f47d135f6278f76 100644 --- a/example/42_groupnorm/CMakeLists.txt +++ b/example/42_groupnorm/CMakeLists.txt @@ -1,3 +1,5 @@ -add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) -add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) -add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) + add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) + add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) +endif() diff --git a/example/42_groupnorm/common.hpp b/example/42_groupnorm/common.hpp index 780154b26cb0287ed50ff4091aed733d59ee12dd..c8f91eb53b5b943bfbdd264b842cf7762a176b82 100644 --- a/example/42_groupnorm/common.hpp +++ b/example/42_groupnorm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp index b07a26c4c937b7e26420d824b5d00ea7d85b4f36..cc107b63dcd87735ba40ebbe7633c61993406166 100644 --- a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp +++ b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/42_groupnorm/groupnorm_splitk_fp16.cpp b/example/42_groupnorm/groupnorm_splitk_fp16.cpp index fd4bfe38072082cce83be08e2efac101024408db..057b240a63fc38217c095b6f3a058e1d4db6ea27 100644 --- a/example/42_groupnorm/groupnorm_splitk_fp16.cpp +++ b/example/42_groupnorm/groupnorm_splitk_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/42_groupnorm/groupnorm_swish_fp16.cpp b/example/42_groupnorm/groupnorm_swish_fp16.cpp index c52243bfb0c4724df5fe418dc741362f8a5c92d5..363f22ed4c0015c9a96aa98a8b6b5302978e3d25 100644 --- a/example/42_groupnorm/groupnorm_swish_fp16.cpp +++ b/example/42_groupnorm/groupnorm_swish_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/42_groupnorm/run_groupnorm_example.inc b/example/42_groupnorm/run_groupnorm_example.inc index d1016a3b12dbd95d4230678c4a34fcc78b105bbf..16065c8d46a42433416e21e9607dc3fcb708d817 100644 --- a/example/42_groupnorm/run_groupnorm_example.inc +++ b/example/42_groupnorm/run_groupnorm_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt index c29f18f1627ef20fe69cb3751d3a7766ffbd236c..7e070f5357999b73ce647489c37613a4e52329f8 100644 --- a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt +++ b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt @@ -1,2 +1,6 @@ -add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) -add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) +endif() diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp index 7ac4b68272e2b0a1a5cbc09fa3734e190d6897da..ebba88cf41fc493103dc78b1252e6d3ad5e19739 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -8,7 +8,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp index 764e55ef558bd977fca8f25f16b2bba4c111114b..4ab26293ccfddeb5117d409d0af9c00a303ce984 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -8,7 +8,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index 0e0091a986ba6f92f49eb83dc8ae068636c7385e..877a82031598ba63c3c331b0bf7df34a2c07d5a7 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -1,2 +1,4 @@ -add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) -add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) + add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) +endif() diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index 7d6ff12eeafe8de6f45e93567959d86d748c783c..76361f87a5b58071b149cab505f7c0485abfadab 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/46_gemm_add_multiply/CMakeLists.txt b/example/46_gemm_add_multiply/CMakeLists.txt index bfe057e8da677f30844a742b032522a2a61a947c..cf7d81f895b82244a0ae7ac127e9df4c9a03324f 100644 --- a/example/46_gemm_add_multiply/CMakeLists.txt +++ b/example/46_gemm_add_multiply/CMakeLists.txt @@ -1,2 +1,6 @@ -add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) -add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + if(DL_KERNELS) + add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) + endif() + add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) +endif() diff --git a/example/46_gemm_add_multiply/common.hpp b/example/46_gemm_add_multiply/common.hpp index 3ba78dfe47ba4cac2f647bf5cb609561e6989275..2c656cf44159a594f0b8ba3b40b7d8d936d57cb2 100644 --- a/example/46_gemm_add_multiply/common.hpp +++ b/example/46_gemm_add_multiply/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp b/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp index 28c3939fa611d536600bdce9c7609705e0373da1..58a399f226c1b59073dd000e382e673badb5e1fb 100644 --- a/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp +++ b/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp" diff --git a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp index d5aa41f1b6d5a97c5c1b614e2765a86fde7343e1..56417b101d9528674535ab93f6fdc7ab36d47c3d 100644 --- a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp +++ b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt index b60789ea371c08342d1263421b117dd0b40d3035..14432f6e23d4adbf09e0df6c805a68f1a5571f69 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt +++ b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -1,3 +1,8 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) -endif() + set(target 1) + endif() +endforeach() diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp index 30c98e534a99fa05ff6ebec044c5b70df7419254..a90a6340a431a55c271dca3d3d0d1771382218f5 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -121,7 +121,8 @@ using DeviceOpInstance = 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock - MaskingSpec>; // MaskingSpecialization + MaskingSpec, // MaskingSpecialization + 1>; // Ref Gemm0: fp16 in, fp32 out using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm @@ -18,7 +18,45 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" -template +std::vector f_tensor_strides_ncdhw(ck::index_t N_, + ck::index_t C_, + ck::index_t D, + ck::index_t H, + ck::index_t W, + TensorLayout layout) +{ + using namespace ck::literals; + (void)N_; + if constexpr(ck::is_same::value) + return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; + else if constexpr(ck::is_same::value) + return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; +}; + +template +HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, + std::size_t C_, + std::size_t D, + std::size_t H, + std::size_t W, + TensorLayout layout) +{ + using namespace ck::literals; + + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + } +}; + +template ; // InSrcOutDstVectorSize - - const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1; - const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; - const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + const ck::index_t Zs = (Z - 1) * window_dilation_d + 1; + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1; + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; const std::vector window_spatial_lengths{Z, Y, X}; const std::vector window_strides{ window_stride_d, window_stride_h, window_stride_w}; + const std::vector window_dilations{ + window_dilation_d, window_dilation_h, window_dilation_w}; const std::vector input_left_pads{in_left_pad_d, in_left_pad_h, in_left_pad_w}; const std::vector input_right_pads{in_right_pad_d, in_right_pad_h, in_right_pad_w}; - // tensor layout - auto f_host_tensor_descriptor = [](std::size_t N_, - std::size_t C_, - std::size_t D, - std::size_t H, - std::size_t W, - auto layout) { - using namespace ck::literals; - - if constexpr(ck::is_same::value) - { - return HostTensorDescriptor({N_, C_, D, H, W}, - {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); - } - else if constexpr(ck::is_same::value) - { - return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); - } - }; - Tensor in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi, InLayout{})); Tensor out_n_c_do_ho_wo_host( f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); @@ -126,10 +135,11 @@ bool pool3d_test(bool do_verification, {N, C, Di, Hi, Wi}, {Z, Y, X}, {N, C, Do, Ho, Wo}, - {Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}, - {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, - {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, + f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, InLayout{}), + f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}), + f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}), window_strides, + window_dilations, input_left_pads, input_right_pads, {2, 3, 4}); @@ -165,6 +175,7 @@ bool pool3d_test(bool do_verification, out_indices_n_c_do_ho_wo_host, window_spatial_lengths, window_strides, + window_dilations, input_left_pads, input_right_pads); diff --git a/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp b/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp index 4d3686bcb0c1f991f29da38ec5b99700e4c8155a..b9ac61033d5e4aa26a6f4fc20dae04d04dafa106 100644 --- a/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp +++ b/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -27,31 +27,49 @@ static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; static constexpr bool OutputIndex = false; static constexpr bool PropagateNan = false; +using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool3dFwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + int main() { bool do_verification = true; bool time_kernel = false; // Pool shape - ck::index_t N = 2; - ck::index_t C = 32; - ck::index_t Z = 2; - ck::index_t Y = 2; - ck::index_t X = 2; - ck::index_t Di = 30; - ck::index_t Hi = 30; - ck::index_t Wi = 30; - ck::index_t window_stride_d = 2; - ck::index_t window_stride_h = 2; - ck::index_t window_stride_w = 2; - ck::index_t in_left_pad_d = 1; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_d = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_d = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; - bool pass = pool3d_test + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "maxpool2d_bwd_common.hpp" + +using InDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +using IndexDataType = int32_t; +using ComputeDataType = float; +using DInDataType = ck::bhalf_t; +using DOutDataType = ck::bhalf_t; + +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 1; + ck::index_t C = 1; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 32; + ck::index_t Wi = 32; + ck::index_t window_stride_h = 1; + ck::index_t window_stride_w = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + bool pass = maxpool_bwd_test(do_verification, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + window_dilation_h, + window_dilation_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2c1e6693755e3581623274c323ec34a18f99c68e --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp" + +template +bool maxpool_bwd_test(bool do_verification, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Y, + ck::index_t X, + ck::index_t Hi, + ck::index_t Wi, + ck::index_t window_stride_h, + ck::index_t window_stride_w, + ck::index_t window_dilation_h, + ck::index_t window_dilation_w, + ck::index_t in_left_pad_h, + ck::index_t in_left_pad_w, + ck::index_t in_right_pad_h, + ck::index_t in_right_pad_w) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool2dFwd_NHWC_NHWC; // InSrcOutDstVectorSize + + using DeviceMaxPoolBwdInstance = ck::tensor_operation::device:: + DeviceMaxPoolBwdImpl; + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + const std::vector window_spatial_lengths{Y, X}; + const std::vector window_strides{window_stride_h, window_stride_w}; + const std::vector window_dilations{window_dilation_h, window_dilation_w}; + const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; + const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + using namespace ck::literals; + // reference need Tensor with NCHW order + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + }; + + // in + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi)); + + // out + Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + + // indices + Tensor indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + + // dout + Tensor dout_n_c_ho_wo(f_host_tensor_descriptor(N, C, Ho, Wo)); + + // din + Tensor din_n_c_hi_wi_host(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor din_n_c_hi_wi_device(f_host_tensor_descriptor(N, C, Hi, Wi)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl; + std::cout << "indices_n_c_ho_wo: " << indices_n_c_ho_wo_host.mDesc << std::endl; + std::cout << "dout_n_c_ho_wo: " << dout_n_c_ho_wo.mDesc << std::endl; + std::cout << "din_n_c_hi_wi: " << din_n_c_hi_wi_host.mDesc << std::endl; + + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + dout_n_c_ho_wo.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem indices_device_buf(sizeof(IndexDataType) * + indices_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem dout_device_buf(sizeof(DOutDataType) * dout_n_c_ho_wo.mDesc.GetElementSpaceSize()); + DeviceMem din_device_buf(sizeof(DInDataType) * + din_n_c_hi_wi_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + dout_device_buf.ToDevice(dout_n_c_ho_wo.mData.data()); + + auto pool_fwd = DevicePoolFwdInstance{}; + auto pool_fwd_invoker_ptr = pool_fwd.MakeInvokerPointer(); + auto pool_fwd_argument_ptr = pool_fwd.MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + {N, C, Hi, Wi}, + window_spatial_lengths, + {N, C, Ho, Wo}, + {C * Hi * Wi, 1, Wi * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3}); + + if(!pool_fwd.IsSupportedArgument(pool_fwd_argument_ptr.get())) + { + throw std::runtime_error("wrong! pool_fwd with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time_fwd = + pool_fwd_invoker_ptr->Run(pool_fwd_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + auto pool_bwd = DeviceMaxPoolBwdInstance{}; + auto pool_bwd_invoker_ptr = pool_bwd.MakeInvokerPointer(); + auto pool_bwd_argument_ptr = pool_bwd.MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + dout_n_c_ho_wo.mDesc.GetElementSpaceSize(), + din_n_c_hi_wi_device.mDesc.GetElementSpaceSize(), + window_spatial_lengths, + window_strides, + window_dilations); + + if(!pool_bwd.IsSupportedArgument(pool_bwd_argument_ptr.get())) + { + throw std::runtime_error("wrong! pool_bwd with the specified compilation parameters does " + "not support this problem"); + } + + size_t pool_bwd_workspace_sz = pool_bwd.GetWorkSpaceSize(pool_bwd_argument_ptr.get()); + DeviceMem pool_bwd_workspace_device_buf(pool_bwd_workspace_sz); + pool_bwd.SetWorkSpacePointer(pool_bwd_argument_ptr.get(), + pool_bwd_workspace_device_buf.GetDeviceBuffer()); + + float ave_time_bwd = + pool_bwd_invoker_ptr->Run(pool_bwd_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Pool fwd perf: " << ave_time_fwd << " ms" << std::endl; + std::cout << "Pool bwd perf: " << ave_time_bwd << " ms" << std::endl; + + bool pass = true; + + if(do_verification) + { + using ReferencePoolingFwdInstance = + ck::tensor_operation::host::ReferencePoolingFwd<4, + 2, + InDataType, + OutDataType, + ComputeDataType, + IndexDataType, + ck::ReduceTensorOp::MAX, + PropagateNan, + true>; + + auto ref_pooling_fwd = ReferencePoolingFwdInstance{}; + auto ref_pooling_fwd_invoker = ref_pooling_fwd.MakeInvoker(); + auto ref_pooling_fwd_argument = ref_pooling_fwd.MakeArgument(in_n_c_hi_wi, + out_n_c_ho_wo_host, + indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument); + + using ReferencePoolingBwdInstance = + ck::tensor_operation::host::ReferenceMaxPoolBwd; + + auto ref_pooling_bwd = ReferencePoolingBwdInstance{}; + auto ref_pooling_bwd_invoker = ref_pooling_bwd.MakeInvoker(); + auto ref_pooling_bwd_argument = ref_pooling_bwd.MakeArgument( + dout_n_c_ho_wo, indices_n_c_ho_wo_host, din_n_c_hi_wi_host, PassThrough{}); + + ref_pooling_bwd_invoker.Run(ref_pooling_bwd_argument); + + out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); + indices_device_buf.FromDevice(indices_n_c_ho_wo_device.mData.data()); + din_device_buf.FromDevice(din_n_c_hi_wi_device.mData.data()); + + pass = pass && ck::utils::check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host); + pass = pass && ck::utils::check_err(indices_n_c_ho_wo_device, indices_n_c_ho_wo_host); + pass = pass && ck::utils::check_err(din_n_c_hi_wi_device, din_n_c_hi_wi_host); + } + + return (pass); +}; diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4f982d855bbddd7a95d683d320e00bb12da0c3d --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "maxpool2d_bwd_common.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using IndexDataType = int32_t; +using ComputeDataType = float; +using DInDataType = ck::half_t; +using DOutDataType = ck::half_t; + +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 1; + ck::index_t C = 1; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 32; + ck::index_t Wi = 32; + ck::index_t window_stride_h = 1; + ck::index_t window_stride_w = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + bool pass = maxpool_bwd_test(do_verification, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + window_dilation_h, + window_dilation_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c14928a9367fc69f44260a88ca1200944f76e44e --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "maxpool2d_bwd_common.hpp" + +using InDataType = float; +using OutDataType = float; +using IndexDataType = int32_t; +using ComputeDataType = float; +using DInDataType = float; +using DOutDataType = float; + +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 1; + ck::index_t C = 1; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 32; + ck::index_t Wi = 32; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + bool pass = maxpool_bwd_test(do_verification, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + window_dilation_h, + window_dilation_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/50_put_element/CMakeLists.txt b/example/50_put_element/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..eca410008fde33b6236c2c14c10650a74640c53d --- /dev/null +++ b/example/50_put_element/CMakeLists.txt @@ -0,0 +1,3 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_example_executable(example_put_element_fp16 put_element_fp16.cpp) +endif() diff --git a/example/50_put_element/put_element_fp16.cpp b/example/50_put_element/put_element_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..791747d8756962cb2e24c606d4d9f50ba3ee3410 --- /dev/null +++ b/example/50_put_element/put_element_fp16.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using XDataType = ck::half_t; +using YDataType = ck::half_t; +using IndexDataType = int32_t; + +using YElementwiseOp = ck::tensor_operation::element_wise::PassThrough; + +using DeviceInstance = + ck::tensor_operation::device::DevicePutElementImpl; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + int N = 1024; + + Tensor x(HostTensorDescriptor{N}); + Tensor indices(HostTensorDescriptor{N}); + Tensor y(HostTensorDescriptor{N}); + + x.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + for(int i = 0; i < N; ++i) + indices(i) = i; + + DeviceMem x_device_buf(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem y_device_buf(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem indices_device_buf(sizeof(IndexDataType) * indices.mDesc.GetElementSpaceSize()); + + x_device_buf.ToDevice(x.mData.data()); + indices_device_buf.ToDevice(indices.mData.data()); + + auto put_instance = DeviceInstance{}; + auto put_invoker_ptr = put_instance.MakeInvokerPointer(); + auto put_argument_ptr = put_instance.MakeArgumentPointer( + static_cast(x_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(y_device_buf.GetDeviceBuffer()), + N, + N, + YElementwiseOp{}); + + if(!put_instance.IsSupportedArgument(put_argument_ptr.get())) + { + throw std::runtime_error("argument is not supported!"); + } + + float ave_time = + put_invoker_ptr->Run(put_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + Tensor y_host(HostTensorDescriptor{N}); + + for(int i = 0; i < N; ++i) + { + IndexDataType idx = indices(i); + y_host(idx) = x(i); + } + + y_device_buf.FromDevice(y.mData.data()); + pass = ck::utils::check_err(y, y_host); + } + + return (pass ? 0 : 1); +} diff --git a/example/51_avgpool3d_bwd/CMakeLists.txt b/example/51_avgpool3d_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fef0c66835a804e943346f099481c98336a96b4d --- /dev/null +++ b/example/51_avgpool3d_bwd/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_avgpool3d_bwd_bf16 avgpool3d_bwd_bf16.cpp) +add_example_executable(example_avgpool3d_bwd_fp16 avgpool3d_bwd_fp16.cpp) +add_example_executable(example_avgpool3d_bwd_fp32 avgpool3d_bwd_fp32.cpp) diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5ab6ef24a4da1cb23719719a4c6d9cb8187ded3 --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" + +#include "avgpool3d_bwd_common.hpp" + +using DOutDataType = ck::bhalf_t; +using DInDataType = ck::bhalf_t; +using ComputeDataType = float; + +#if 1 +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; +#else +using DOutLayout = ck::tensor_layout::convolution::NCDHW; +using DInLayout = ck::tensor_layout::convolution::NCDHW; +#endif + +using DevicePoolBwdInstance = + ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + std::vector window_lengths = {5, 5, 5}; + std::vector window_strides = {2, 2, 2}; + std::vector window_dilations = {2, 2, 2}; + std::vector dinput_left_pads = {0, 0, 0}; + std::vector dinput_right_pads = {0, 0, 0}; + + ck::index_t N = 1; + ck::index_t C = 16; + ck::index_t Di = 40; + ck::index_t Hi = 40; + ck::index_t Wi = 40; + + pool3d_bwd_test( + true, + false, + N, + C, + Di, + Hi, + Wi, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); +} diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..394f046b1e9d2136e42001f366d7423f1f5e8600 --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp" + +template +std::vector f_tensor_strides_ncdhw(ck::index_t N_, + ck::index_t C_, + ck::index_t D, + ck::index_t H, + ck::index_t W, + TensorLayout layout) +{ + using namespace ck::literals; + (void)N_; + if constexpr(ck::is_same::value) + return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; + else if constexpr(ck::is_same::value) + return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; +}; + +template +HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, + std::size_t C_, + std::size_t D, + std::size_t H, + std::size_t W, + TensorLayout layout) +{ + using namespace ck::literals; + + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + } +}; + +template +bool pool3d_bwd_test(bool do_verification, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Di, + ck::index_t Hi, + ck::index_t Wi, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector dinput_left_pads, + std::vector dinput_right_pads) +{ + auto OutSpatialLength = [&](auto InSpatialLength, int index) { + ck::index_t left_pad = dinput_left_pads[index]; + ck::index_t right_pad = dinput_right_pads[index]; + ck::index_t window_len = window_lengths[index]; + ck::index_t stride = window_strides[index]; + ck::index_t dilation = window_dilations[index]; + ck::index_t eff = (window_len - 1) * dilation + 1; + return (InSpatialLength + left_pad + right_pad - eff) / stride + 1; + }; + + ck::index_t Do = OutSpatialLength(Di, 0); + ck::index_t Ho = OutSpatialLength(Hi, 1); + ck::index_t Wo = OutSpatialLength(Wi, 2); + + Tensor dout(f_host_tensor_descriptor(N, C, Do, Ho, Wo, DOutLayout{})); + Tensor din_dev(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{})); + Tensor din_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{})); + + std::cout << "dout: " << dout.mDesc << std::endl; + std::cout << "din_host: " << din_host.mDesc << std::endl; + + dout.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem dout_device_buf(sizeof(DOutDataType) * dout.mDesc.GetElementSpaceSize()); + DeviceMem din_device_buf(sizeof(DInDataType) * din_dev.mDesc.GetElementSpaceSize()); + + dout_device_buf.ToDevice(dout.mData.data()); + din_device_buf.SetZero(); + + auto pool = DevicePoolBwdInstance{}; + auto invoker_ptr = pool.MakeInvokerPointer(); + auto argument_ptr = + pool.MakeArgumentPointer(static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + {N, C, Do, Ho, Wo}, + {N, C, Di, Hi, Wi}, + f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, DOutLayout{}), + f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, DInLayout{}), + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); + + if(!pool.IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error("wrong! device_op with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + std::cout << "Perf: " << ave_time << std::endl; + + bool pass = true; + + if(do_verification) + { + auto ref_pool = + ck::tensor_operation::host::ReferenceAvgPoolBwd<3, DInDataType, DOutDataType>(); + + auto ref_invoker = ref_pool.MakeInvoker(); + + auto ref_argument = ref_pool.MakeArgument(din_host, + dout, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); + + ref_invoker.Run(ref_argument); + + din_device_buf.FromDevice(din_dev.mData.data()); + pass = ck::utils::check_err(din_dev, din_host); + } + + return pass; +} diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..578f563d5c29e8818f9319e7fb79920a785f58c5 --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" + +#include "avgpool3d_bwd_common.hpp" + +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; +using ComputeDataType = float; + +#if 1 +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; +#else +using DOutLayout = ck::tensor_layout::convolution::NCDHW; +using DInLayout = ck::tensor_layout::convolution::NCDHW; +#endif + +using DevicePoolBwdInstance = + ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + std::vector window_lengths = {5, 5, 5}; + std::vector window_strides = {2, 2, 2}; + std::vector window_dilations = {2, 2, 2}; + std::vector dinput_left_pads = {0, 0, 0}; + std::vector dinput_right_pads = {0, 0, 0}; + + ck::index_t N = 1; + ck::index_t C = 16; + ck::index_t Di = 40; + ck::index_t Hi = 40; + ck::index_t Wi = 40; + + pool3d_bwd_test( + true, + false, + N, + C, + Di, + Hi, + Wi, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); +} diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2910c55679ca8fefce81cc0c07b7d6a439af182 --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" + +#include "avgpool3d_bwd_common.hpp" + +using DOutDataType = float; +using DInDataType = float; +using ComputeDataType = float; + +#if 1 +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; +#else +using DOutLayout = ck::tensor_layout::convolution::NCDHW; +using DInLayout = ck::tensor_layout::convolution::NCDHW; +#endif + +using DevicePoolBwdInstance = + ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + std::vector window_lengths = {5, 5, 5}; + std::vector window_strides = {2, 2, 2}; + std::vector window_dilations = {2, 2, 2}; + std::vector dinput_left_pads = {0, 0, 0}; + std::vector dinput_right_pads = {0, 0, 0}; + + ck::index_t N = 1; + ck::index_t C = 16; + ck::index_t Di = 40; + ck::index_t Hi = 40; + ck::index_t Wi = 40; + + pool3d_bwd_test( + true, + false, + N, + C, + Di, + Hi, + Wi, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); +} diff --git a/example/52_image_to_column/CMakeLists.txt b/example/52_image_to_column/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..226e1fc5ae268a9c34ab0b2e9d7c0c3bf7f84520 --- /dev/null +++ b/example/52_image_to_column/CMakeLists.txt @@ -0,0 +1,10 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_image_to_column) + add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) + add_dependencies(example_image_to_column example_image_to_column_f32) + set(target 1) + endif() +endforeach() diff --git a/example/52_image_to_column/common.hpp b/example/52_image_to_column/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8510fa1e6d712d900a4faa113fa2dde370d19ddf --- /dev/null +++ b/example/52_image_to_column/common.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp" + +template +using S = ck::Sequence; + +static inline constexpr ck::index_t NDimSpatial = 2; + +using FP32 = float; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +#define DefaultConvParams \ + ck::utils::conv::ConvParam \ + { \ + NDimSpatial, 1, 32, 1, 1, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_params) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + config = ExecutionConfig{}; + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + conv_params = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} diff --git a/example/52_image_to_column/image_to_column_f32.cpp b/example/52_image_to_column/image_to_column_f32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8a7e5f221606f6ff12f860db162eb083def3921 --- /dev/null +++ b/example/52_image_to_column/image_to_column_f32.cpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using InDataType = FP32; +using OutDataType = FP32; + +using InLayout = ck::tensor_layout::convolution::GNHWC; + +// clang-format off +using DeviceImgToColInstance = ck::tensor_operation::device::DeviceImageToColumnImpl + //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + < NDimSpatial, InLayout, InDataType, OutDataType, 256, 128, 128, S<16, 16>, 1>; +// clang-format on + +bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params) +{ + + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck::index_t NDoHoWo = + N * ck::accumulate_n( + conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const ck::index_t CZYX = + C * ck::accumulate_n( + conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto in_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + const auto out_desc = HostTensorDescriptor({NDoHoWo, CZYX}); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array input_g_n_c_wis_strides{}; + std::array output_m_k_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(conv_params.input_spatial_lengths_, input_spatial_lengths); + copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths); + copy(conv_params.output_spatial_lengths_, output_spatial_lengths); + copy(in_desc.GetStrides(), input_g_n_c_wis_strides); + copy(out_desc.GetStrides(), output_m_k_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + Tensor in(in_desc); + Tensor out_device(out_desc); + Tensor out_host(out_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out_device.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + + // reset input to zero + out_device_buf.SetZero(); + + static_assert(std::is_default_constructible_v); + + // do conv + auto img2col = DeviceImgToColInstance{}; + auto invoker = img2col.MakeInvoker(); + auto argument = img2col.MakeArgument(in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + input_g_n_c_wis_strides, + output_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + if(!img2col.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_img2col with the specified compilation parameters does " + "not support this img2col problem" + << std::endl; + + return false; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + std::size_t num_btype = NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + if(config.do_verification) + { + auto ref_image_to_column = ck::tensor_operation::host:: + ReferenceImageToColumn(); + + auto ref_invoker = ref_image_to_column.MakeInvoker(); + + auto ref_argument = ref_image_to_column.MakeArgument(in, + out_host, + conv_params.filter_spatial_lengths_, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_); + + if(!ref_image_to_column.IsSupportedArgument(&ref_argument)) + { + std::cerr << "wrong! ref_img2col with the specified compilation parameters does " + "not support this img2col problem" + << std::endl; + return false; + } + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device.mData, out_host.mData); + } + + return true; +} + +int RunImageToColumnExample(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatial dimensions" << std::endl; + return EXIT_FAILURE; + } + + return !RunImageToColumn(config, conv_params); +} + +int main(int argc, char* argv[]) { return RunImageToColumnExample(argc, argv); } diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 1626597ed2164311535e1e570fcdeb9c44cc3fd1..069ff7fc748faaee630a6aaf66eafe4e6f28fd2b 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -1,8 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/config.h" + #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -27,11 +29,27 @@ #define CK_WAVELET_MIN_BLOCK_PER_CU 2 #endif +// kernel attribute: amdgpu_waves_per_eu() +#ifdef CK_USE_WAVES_PER_EU +// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute +#ifndef CK_MIN_WAVES_PER_EU +#define CK_MIN_WAVES_PER_EU 0 +#endif + +#ifndef CK_MAX_WAVES_PER_EU +#define CK_MAX_WAVES_PER_EU 0 +#endif + +#else +#define CK_USE_WAVES_PER_EU 0 +#endif + // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) // for GPU code + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx1030__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 @@ -44,7 +62,7 @@ #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 #elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \ - defined(__gfx940__) // for GPU code + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 @@ -53,15 +71,16 @@ // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code #define CK_USE_AMD_MFMA #endif -#if(defined(__gfx90a__) || defined(__gfx940__)) +#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) #define CK_USE_AMD_MFMA_BF16_1K_OP #endif -#if defined(__gfx940__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define CK_USE_AMD_MFMA_GFX940 #endif @@ -84,13 +103,15 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif -#if(defined(__gfx90a__) || defined(__gfx940__)) // for GPU code +#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__)) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #else #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 @@ -99,8 +120,15 @@ // inline asm #define CK_USE_AMD_INLINE_ASM 1 -// inner product (DLOP) -#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1 +// inner product (V_MAC/V_FMAC) +#define CK_USE_AMD_V_MAC_INLINE_ASM 1 + +// V_DOT inline instructions, less efficient since they require adding +// `s_nop`s to avoid hazard +#define CK_USE_AMD_V_DOT_INLINE_ASM 0 + +// inner product using V_DOT with DPP8 modifiers +#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 @@ -144,6 +172,10 @@ #define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1 // experimental feature: add instances using pipeline v2 #define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1 +// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy) +#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT +#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0 +#endif // hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be @@ -169,6 +201,7 @@ // workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 + // flag to enable (1) or disable (0) the debugging output in some kernels #define DEBUG_LOG 0 diff --git a/include/ck/config.h.in b/include/ck/config.h.in new file mode 100644 index 0000000000000000000000000000000000000000..13dc5da5d179443b603cce5a2aa6274061856db3 --- /dev/null +++ b/include/ck/config.h.in @@ -0,0 +1,102 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_CONFIG_H_IN +#define CK_CONFIG_H_IN + +// clang-format off +// +// DataType supports in the current CK build +// +#ifndef DTYPES +#cmakedefine DTYPES "@DTYPES@" +#endif +// if DTYPES is not defined, enable all datatypes in headerfiles +#ifndef CK_ENABLE_ALL_DTYPES +#cmakedefine CK_ENABLE_ALL_DTYPES @CK_ENABLE_ALL_DTYPES@ +#if defined(CK_ENABLE_ALL_DTYPES) +#ifndef CK_ENABLE_INT8 +#define CK_ENABLE_INT8 "ON" +#endif +#ifndef CK_ENABLE_FP8 +#define CK_ENABLE_FP8 "ON" +#endif +#ifndef CK_ENABLE_FP16 +#define CK_ENABLE_FP16 "ON" +#endif +#ifndef CK_ENABLE_BF16 +#define CK_ENABLE_BF16 "ON" +#endif +#ifndef CK_ENABLE_FP32 +#define CK_ENABLE_FP32 "ON" +#endif +#ifndef CK_ENABLE_FP64 +#define CK_ENABLE_FP64 "ON" +#endif +#endif +#endif +// if DTYPES are selectively enabled +#ifndef CK_ENABLE_INT8 +#cmakedefine CK_ENABLE_INT8 @CK_ENABLE_INT8@ +#endif + +#ifndef CK_ENABLE_FP8 +#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@ +#endif + +#ifndef CK_ENABLE_FP16 +#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@ +#endif + +#ifndef CK_ENABLE_BF16 +#cmakedefine CK_ENABLE_BF16 @CK_ENABLE_BF16@ +#endif + +#ifndef CK_ENABLE_FP32 +#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ +#endif + +#ifndef CK_ENABLE_FP64 +#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ +#endif + +// +// Legacy DL kernel supports in the current CK build +// by default DL kernels are turned OFF +// +#ifndef CK_ENABLE_DL_KERNELS +#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ +#endif + +// +// Instances supports in the current CK build +// +#ifndef CK_ENABLE_INSTANCES_ONLY +#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@ +#endif + +// clang-format on + +#endif // CK_CONFIG_H_IN diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index e2cbdb733272d37aa5dbc7a746e86911d6b8644f..be1dbc1657e5bf2ac15d37184175c16b79f0b8eb 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -51,4 +51,11 @@ inline std::string get_device_name() return name; } +inline bool is_xdl_supported() +{ + return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942"; +} + } // namespace ck diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index d3dc8eaf1eb8b87207256ea4a521e23d5a49ca9c..af7bebd9d6afbf59f1852156b1708ed843d9a9f6 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/host_utility/io.hpp b/include/ck/host_utility/io.hpp index ac8719592db8b48996bfbc65dc656d6d96bde545..55734bab2e469a12cd02c417cbb25d9c8a49727c 100644 --- a/include/ck/host_utility/io.hpp +++ b/include/ck/host_utility/io.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 24f2121674c7d371a7d13c0b1c83e7f0a5e059be..3d27103dcb877eba900aa120f545aa8d1bd5920a 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -73,3 +73,72 @@ float launch_and_time_kernel(const StreamConfig& stream_config, return 0; #endif } + +template +float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, + PreProcessFunc preprocess, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... args) +{ +#if CK_TIME_KERNEL + if(stream_config.time_kernel_) + { +#if DEBUG_LOG + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up 1 time\n"); +#endif + // warm up + preprocess(); + kernel<<>>(args...); + + const int nrepeat = 10; +#if DEBUG_LOG + printf("Start running %d times...\n", nrepeat); +#endif + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + preprocess(); + kernel<<>>(args...); + } + + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + + float total_time = 0; + + hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + + return total_time / nrepeat; + } + else + { + preprocess(); + kernel<<>>(args...); + + return 0; + } +#else + kernel<<>>(args...); + + return 0; +#endif +} diff --git a/include/ck/host_utility/stream_utility.hpp b/include/ck/host_utility/stream_utility.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9ab49489bbe59b5152ad61c74622be48c33c28a7 --- /dev/null +++ b/include/ck/host_utility/stream_utility.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +static inline int getAvailableComputeUnitCount(const StreamConfig& stream_config) +{ + constexpr int MAX_MASK_DWORDS = 64; + + // assume at most 64*32 = 2048 CUs + uint32_t cuMask[MAX_MASK_DWORDS]; + + for(int i = 0; i < MAX_MASK_DWORDS; i++) + cuMask[i] = 0; + + auto countSetBits = [](uint32_t dword) { + int count = 0; + + while(dword != 0) + { + if(dword & 0x1) + count++; + + dword = dword >> 1; + }; + + return (count); + }; + + hip_check_error(hipExtStreamGetCUMask(stream_config.stream_id_, MAX_MASK_DWORDS, &cuMask[0])); + + int ret = 0; + + for(int i = 0; i < MAX_MASK_DWORDS; i++) + ret += countSetBits(cuMask[i]); + + return (ret); +}; diff --git a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp index 41267536551ea298e111f6d93b8e6f3f8f9ed475..6b118e972e67897f2c9e0cad3a8f959760f82f6c 100644 --- a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index 70ca34555a01436c79ba244ac03572bb4e9520b4..505a602b240428bcd1f0f81018fef0c4716c80b7 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index fee679f91060aca623ab498362de78fd8118fee2..d719ef9760d79297600d7524167eba78cd137831 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_STATIC_TENSOR_HPP #define CK_STATIC_TENSOR_HPP diff --git a/include/ck/tensor_description/cluster_descriptor.hpp b/include/ck/tensor_description/cluster_descriptor.hpp index 0c9ea2ff2a0d73b793008a954eaf7293b33ade08..2dfcad8e042e548d15b4bd963fe615e39c04eff1 100644 --- a/include/ck/tensor_description/cluster_descriptor.hpp +++ b/include/ck/tensor_description/cluster_descriptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index 4e4d7593e9083f321f1cabe65faaee3c14259666..ae3139ce78c8b3b881ee36602011dd895532bd83 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1042,13 +1042,13 @@ struct Merge_v2_magic_division using UpLengths = decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); - using LowLengthsMagicDivisorMultipiler = decltype( - generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, - Number{})); + using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); - using LowLengthsMagicDivisorShift = decltype( - generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, - Number{})); + using LowLengthsMagicDivisorShift = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); LowLengths low_lengths_; LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_; @@ -1201,9 +1201,9 @@ struct Merge_v2r2_magic_division lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, Number{})); - using LowLengthsScanMagicDivisorShift = decltype( - generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, - Number{})); + using LowLengthsScanMagicDivisorShift = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); LowLengths low_lengths_; LowLengthsScan low_lengths_scan_; diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index 044a90370095eb53b94b5c2fba81abdbeae82c00..af0a8a34d0e48bde494dacfff5bbb17eda5eb479 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index d42e0a6ff08f60eb1cf6e0f241f93e528bf3514b..3ffac32469a8d97d678bc6b5fe62a7cc5e0b24ab 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index f07d5b1733d9cc96dfac9f18cbbe30510760cabc..f1df2eedd466c81e4b7938c4808075869e6309fd 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp index 461aae72cf7b1879c90e473adb3814e1c4875b52..f3ac041bf9b8dc03802d170fbd7bf08ce6ab9cb1 100644 --- a/include/ck/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp index 17c9100b9fd76418d562c5c175ad24062d8c3415..9a326092d2e0fd8392ec42a8c0a82b4167076373 100644 --- a/include/ck/tensor_description/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index 8b1b7be11ef7b3cd19f11bd400b8a96146921bf7..f23404a1d7ba81f9c65bf111da51701f2814d4bc 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,7 +11,7 @@ namespace ck { // C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] -// A and B are visable to the whole block, C is distributed among each thread +// A and B are visible to the whole block, C is distributed among each thread // Assume: // 1. A: // 1. ABlockDesc_BK0_BM_BK1 is known at compile-time diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp index 33120bd86ff01da47f7b80c593bc966785cb3711..b0143366c1de5f8e6ec52cceefd3c20e4e77cbb7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp index f45655721fe4de11b50b91e0dc5d22790ff73bea..0d092da5168c5d5629eecc2aac7cbdd53e210277 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d62ed4b15dd3a6ea078ef7ed0b70e818a30af611 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp" + +namespace ck { + +/** + * Blockwise GEMM that uses DPP instruction modifier to limit the amount of data loaded for each + * thread by sharing the data between threads in a lanegroup. + * + * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`, there are + * `MRepeat` iterations for `M` dimension and `NRepeat` for `N` one. + * In total, the algorithm runs using + * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves. + */ +template +struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto dpp_gemm = DppGemm{}; + + static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M(); + const auto dpp_a_idx_k = dpp_a_idx[I0]; + const auto dpp_a_idx_m = dpp_a_idx[I1]; + return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k); + } + + __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N(); + const auto dpp_b_idx_k = dpp_b_idx[I0]; + const auto dpp_b_idx_n = dpp_b_idx[I1]; + return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk(); + const auto blk_m_offset = blk_idx[I0]; + const auto blk_n_offset = blk_idx[I1]; + + constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_m_offset))[I0]; + const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_n_offset))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "Wrong! Block descriptors should be known at the time of compilation."); + +#if defined(__HIP_DEVICE_COMPILE__) + // Host wave size can be different than the device one and this assert could fail for host, + // but it does matter only for device. + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); +#endif + + static_assert(MPerBlock % (MPerDpp * MRepeat) == 0, + "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat."); + static_assert(NPerBlock % (NPerDpp * NRepeat) == 0, + "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat."); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return c_block_desc_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + return c_block_desc_g_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return c_grid_desc_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return c_grid_desc_g_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using dpp_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + dpp_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegDpp] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, dpp_gemm.GetRegSizePerDpp())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index c09b8f7d1a63c96e9ce8a921ea8129d0951ab89b..b3d45f3d0c843c5e72a37b4a9a7617f2f1abc16a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 5328dfde9bc09039af03892467d94545002c56fd..1fee9c3225e7be667588baa87b4e3d80029f511c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1,30 +1,16 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { -enum struct LoopScheduler -{ - Default, - Interwave, -}; - -constexpr LoopScheduler make_default_loop_scheduler() -{ -#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING - return LoopScheduler::Interwave; -#else - return LoopScheduler::Default; -#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING -} - template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index aa814ab00939d0c36866af01b6cf2057dcbd5121..8ae1ba3f34c10d8359b470ed1172f1b70a7fa8b5 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp index 7e62a822a8f042d238905295b7f88c1d9bdd88f8..2fb7242708da58828b1d6f5c6b78ea472783814f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -35,8 +35,8 @@ struct BlockwiseSoftmax static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0); static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1); - using ThreadSliceDesc_M = decltype( - make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))); + using ThreadSliceDesc_M = decltype(make_naive_tensor_descriptor_packed( + make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))); using ThreadwiseMaxReduce = typename conditional< IgnoreNaN, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index 03e4d42d3a1f9eb1b180c95368b905e619e67110..d8da134a3415a6976b27d0a4fdd7f13798d0245b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp index 316508651e4bd4dad2e47edf020a4145176fa753..820a08fc482b44a3509188db0f98aecabb692a82 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/utility/reduction_common.hpp" +#include "ck/utility/get_shift.hpp" namespace ck { @@ -35,10 +35,11 @@ struct BlockwiseWelford static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + template __device__ static inline void - Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) + Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b) { - int count = count_a + count_b; + CountDataType count = count_a + count_b; T count_b_over_count = count == 0 ? type_convert(0) : type_convert(count_b) / count; T delta = mean_b - mean_a; mean_a += delta * count_b_over_count; @@ -46,11 +47,12 @@ struct BlockwiseWelford count_a = count; } - __device__ static void Run(T& mean_value, T& var_value, int& count) + template + __device__ static void Run(T& mean_value, T& var_value, CountDataType& count) { __shared__ T mean_block_buf[BlockSize]; __shared__ T var_block_buf[BlockSize]; - __shared__ int count_block_buf[BlockSize]; + __shared__ CountDataType count_block_buf[BlockSize]; constexpr auto cluster_len_shift = get_shift(); @@ -76,13 +78,13 @@ struct BlockwiseWelford index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + make_tuple(0, indOffset)); - T mean1 = mean_block_buf[offset1]; - T var1 = var_block_buf[offset1]; - int count1 = count_block_buf[offset1]; + T mean1 = mean_block_buf[offset1]; + T var1 = var_block_buf[offset1]; + CountDataType count1 = count_block_buf[offset1]; - T mean2 = mean_block_buf[offset2]; - T var2 = var_block_buf[offset2]; - int count2 = count_block_buf[offset2]; + T mean2 = mean_block_buf[offset2]; + T var2 = var_block_buf[offset2]; + CountDataType count2 = count_block_buf[offset2]; Merge(mean1, var1, count1, mean2, var2, count2); diff --git a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp index 2163ad32383d4f940ab4cf3b041ab2d5c61456cc..82667e235238170806846940e5d153ea20992b73 100644 --- a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp +++ b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/utility/reduction_common.hpp" +#include "ck/utility/get_shift.hpp" #include "ck/utility/reduction_functions_accumulate.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp index 04ad75bd7de4a9657c2aaaf21b40205c3ff9c7d8..2c5fbc3937c5e3782156bf81f46b4fb4b1f1db46 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1 } } + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + } + } + template __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp index 5c47a49b38b6a30eebc9190cc338d4dbc0bc8524..905a59f56e3b42bfb4b2bdf8652a3e81b2872076 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..83cb9fb5de4e047a71bbcaf47f67df925b2dac74 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r1r2 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.template Run( + src_desc, src_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + } + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r1r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp index aa33fc083f15cfd8904c8fa086a866dbc7817e7c..17110c8358dbaf7e602e7f9822968e08f80e14ef 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp index eb5f589a4ada976c9be4a5001fb6fc288c7e9a43..9a5317dd126a818bc88043408427f0f105cb7f6e 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp index 3bd7806389b59542b191e61f5369aebf7300fc6a..993d90e356def70ba54e357889277188b4f21971 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp index a4a29f5d5edc3b840fc7489d5f706231990cad70..cab5e213651dfc0756e03a82ea4389042b1539ae 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat switch(s) { case ConvolutionBackwardDataSpecialization::Default: return "Default"; - case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: - return "FFilter1x1Stride1Pad0"; + case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; default: return "Unrecognized specialization!"; } } diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp index 20b2a152b9d276b0bd0f6e1aa5798b4817b013a9..01bb806789c3d9e50f706d7c88982b4a6624ff57 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index 953ff1e06ed07b24968c7f5c0842161ac66643ed..adfa1689c66509c8c194985365356740d4c90473 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp b/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a2b082c31aaffa401abc126b8770c3ee9f35686 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceAvgPoolBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_k_wos_lengths, + std::vector dout_n_k_wos_strides, + std::vector din_n_k_wos_length, + std::vector din_n_k_wos_strides, + std::vector window_k_c_xs_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 5946daf21ec169612c3207282007e05671882f09..198169011107fb0f236d4657c399ff0534ce2c98 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp index 9fcd893c7a8c38d31f480b13eb8e7bfc97ec7fef..ee7af0117d17acc561e5943ac59611e188b6ba4c 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp index e755913280f73e325daf1352324bab8d3df8a3a9..6cc2c7bb2f6c176f2d84fdda4be2140db5564360 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp index af681127f30404f182475039add4443e46213398..91b4b6b91b6b6be5d08d88e5ca68c9b62fb02aea 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp index 116e62c00907e0f72031552f46b8a3ea975a18f5..f18dc3290600e63edb869d93dcb1a903206e32ab 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp index eacc5976d3ea3885a8294f1ee74a249f0a3d52a7..8234e29486b1af8ae0dd43e58d4be9b63b5ed34f 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp index c1f85e575ce4b2e0d70c05010bfca6e9ed9b84a8..09259224e75cc67e092eeeaab1af793b050cc1f0 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp index bde71806daa0543c41813f2afb0dd34f16bfb76e..be8105c96726dfb3d8931b376a668c3296040706 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp index d39f3b7cbcfc9ccaccca5f755be747cc4896dc1b..2c0da692570b455da42540bb72793b9e629cb823 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp index aa93dd9c19d77fbee7ccdafd1e41e4c4f64735cd..e3962e177ee824981c35fbe2ba04664e4124315b 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp index 8a00fd9db3327a7c536b1ada0ca090bfa847a256..69103b6f44297ed65ed4233e3d458ac6927bcd16 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index aedae53800b03cb78e6af5b15c8c6dfcd9792eea..8484212118c796f68d1bf066eab7b8ddbb231797 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "device_base.hpp" diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp index dbc525c099bca293d106ac6d3c37ac72b2d749cc..118ade8978643f10dcfafbefb552fe485d6697e4 100644 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp index 82054a3c9423f819d17a003b8c22844f0a2273fe..eb1b85ec822aa96822b5a288d55a5d3a0898f1e9 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp index 4b9881088dde07a62e92c5f77a19edc9c4b7670f..4dc11dbefd73619c39febe31a91a953b25049f05 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp index 5a627deeb2221f5e271532d45bc8a544c8cceeba..7d3845666cf2289f34b186a5aaa8ee616812f30a 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp index cc139303c929ae999d182ad09e9a4ef33f5209e1..3a49ac632e7de745e78ad5670dc57f425ab9eb33 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp index f9f913a7c1f1a6cdfdb41f075f2609e8e2407d07..db0e4bd83f46ad001492876ff8a712895a873c22 100644 --- a/include/ck/tensor_operation/gpu/device/device_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp index 9491a92247c71e3938d0f7d81cc0d82741b693de..c56a947ec9e11a0a5dce2cd39ce86d8e6a3efb8c 100644 --- a/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm.hpp index c0af6f80faf606a22678e516be994a75c1d56eca..adf909821dbbdde12526513a7b6218e8ebbd219d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp index 4c2161eaed5b9c7d7913686f44c360c044ca3b07..a7f42c3b35e237cb4ad9edcbcd137f453af59051 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 9113bb7b7454443cce4871bf263433d4f907bfba..a44356dc2405bfe6ed448a8193d55e86bd29415c 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp index a67a09b87416bf11cc285ab1b1d9dc684b17103f..0258858fe50e2d433e5ddfe007e8d06227b47c0c 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp index f4881e32f620be9c31649085813a257e9e84a598..539e83f7cb8a5dfe99b2ecb3da491f9661330265 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp index fcc088ca43d1b8c6224c0015e3eb7434038af4b2..eaa7671c6424b4ecf5ad87c4e790a9c478645b9b 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp index c701bff57f8eb7db155e9abdfd3ab7210e7eeffd..6407aa7e09b81a4dc09dcb1bdca5701aa86d9e14 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed081ad7fc806780dfebf92fbdcbb1e7f99dfb86 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmStreamK : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t NumSKBlocks = 0) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmStreamKPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp index 173c613a325d8c594298ff6751060fecf2f8d453..ba81948440acd892a9877a7615f95642628b6347 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp index 3350aec8d3174010cf0d926c3a91802ce3b8e7e7..7e4bca2bd66a55c70b6a9365e2f1debcc22bb119 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index 1258aed71c502c659f065d08e8b1a84922c26c83..ab9e6adb41963df1f9cd56ca2a7ed3056275defe 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -27,17 +27,16 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator MakeArgumentPointer(const void* p_in, void* p_wei, const void* p_out, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp index 644c7ee9a9107718dc609836e6cd8abe7f2dad21..025c43e75cc204b5bb8b3bdc1083d30aaf760bbc 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fcb2ba6a4d7be839c11cf9794bb7beccf7845d3c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GroupedGemmKernelArgument +{ + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + + index_t M; + index_t N; + index_t K; + + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; +}; + +template +struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm +{ + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; + virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp index b066a4458518b378313225e24817253922962d29..fae65097407cdb26df206308442874c0fc326fe6 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_image_to_column.hpp b/include/ck/tensor_operation/gpu/device/device_image_to_column.hpp new file mode 100644 index 0000000000000000000000000000000000000000..631d5189dd5b57467795eae19093a21bf787ca34 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_image_to_column.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \brief Image to column. + * + * This Device operator converts image ([G, N, Di, Hi, Wi, C]) to the gemm + * problem([N * Do * Ho * Wo, Z * Y * X * C]). G must be equal to 1. + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam InputLayout Input Layout. + * \tparam InputDataType Input Data Type. + * \tparam OutputDataType Output Data Type. + */ +template +struct DeviceImageToColumn : public BaseOperator +{ + + /** + * \brief Make argument pointer for image to column. + * + * \param p_in A pointer to the device memory of the input image. + * \param p_out A pointer to the device memory of the output. + * \param N Convolution batch size. + * \param C Convolution number of channels. + * \param input_spatial_lengths Input spatial lengths. + * \param filter_spatial_lengths Filter spatial lengths. + * \param output_spatial_lengths Output spatial lengths. + * \param input_g_n_c_wis_strides Input strides in order [G, N, C, D, H, W]. + * \param output_m_k_strides Output strides. + * \param conv_filter_strides Convolution filter strides. + * \param conv_filter_dilations Convolution filter dilations. + * \param input_left_pads Convolution left pads. + * \param input_right_pads Convolution right pads. + * \return Pointer to the argument. + */ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_out, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_g_n_c_wis_strides, + const std::array& output_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp b/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a4a9cac1e8ea78b6683b1c4ce9629a8c87b9802 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// For pooling which used indexable operation, such as MaxPool, MinPool...etc +template +struct DeviceMaxPoolBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp index ee4b53e2fcc43f44aba29c3409963a1b94834cbc..f68022ca04de757fee1240ced4b504601b183d91 100644 --- a/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp index 03601ce8312ccea06d60236779aa78a9ea306e01..1f178f9fcb65ffdd7ab09146d94131c1d5c993f2 100644 --- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_permute.hpp b/include/ck/tensor_operation/gpu/device/device_permute.hpp index 9daa2be37338c8309bd2131f263d324fe15b7ab2..c994cf02c6aab5b85fcd2fcc29f2716de28b861d 100644 --- a/include/ck/tensor_operation/gpu/device/device_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp index e801e98a25e90b98de81f4b032df4c5140d86c57..62071c43a1a0275d2868199af95d33354a365db5 100644 --- a/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,6 +17,8 @@ template struct DevicePoolFwd : public BaseOperator @@ -25,13 +27,14 @@ struct DevicePoolFwd : public BaseOperator MakeArgumentPointer(const void* p_in_dev, void* p_out_dev, void* p_out_indices_dev, - std::vector input_lengths, - std::vector window_lengths, - std::vector output_lengths, - std::vector input_stride, - std::vector output_stride, - std::vector indices_stride, - std::vector window_strides, + std::vector input_n_c_wis_lengths, + std::vector window_xs_lengths, + std::vector output_n_c_wos_lengths, + std::vector input_n_c_wis_stride, + std::vector output_n_c_wis_stride, + std::vector indices_n_c_wis_stride, + std::vector window_xs_strides, + std::vector window_xs_dilations, std::vector input_left_pads, std::vector input_right_pads, std::vector pooling_dims) = 0; diff --git a/include/ck/tensor_operation/gpu/device/device_put_element.hpp b/include/ck/tensor_operation/gpu/device/device_put_element.hpp new file mode 100644 index 0000000000000000000000000000000000000000..17df2de37b68c94cbd9ed9091cdfcc106ad8e287 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_put_element.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DevicePutElement : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_input, + const void* p_indices, + void* p_output, + index_t input_length, + index_t output_length, + ElementwiseOperation elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_reduce.hpp index c9209f2d7d681b44c8a27dd06626eb8a4b998a39..c2721b18455c6a46e5973fad32767ae43247f59b 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_softmax.hpp b/include/ck/tensor_operation/gpu/device/device_softmax.hpp index 94f788e5177cd78d0ea917b0869c576ba8fe7bfb..1902fd09eec1ed01954b6cad47b0e97ff898eb40 100644 --- a/include/ck/tensor_operation/gpu/device/device_softmax.hpp +++ b/include/ck/tensor_operation/gpu/device/device_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,7 +18,8 @@ template + index_t Rank, + index_t NumReduceDim> struct DeviceSoftmax : public BaseOperator { // @@ -49,8 +50,6 @@ struct DeviceSoftmax : public BaseOperator AccElementwiseOp acc_elementwise_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; - virtual index_t GetRank() const = 0; - virtual index_t GetNumReduceDim() const = 0; }; template -using DeviceSoftmaxPtr = std::unique_ptr< - DeviceSoftmax>; + index_t Rank, + index_t NumReduceDim> +using DeviceSoftmaxPtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp index f59e6093e2ae024ccbf8827082c755eba8cee88f..eeccd977ccbf4440a349845826d45d5a3274ad37 100644 --- a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index fc913e9ba03611dfbabea020bd4054c0b953726d..0bb45b18c3e19b2ec5f9347c1e811d8734ee45a9 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3a2280a75c27d897972c6f3149b8a0f4e63da7a9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp @@ -0,0 +1,575 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// In and Din = [N, C, Di, Hi, Wi] +// Out and Dout = [N, C, Do, Ho, Wo] +// Out = AvgPoolFwd(In) +// Din = AvgPoolBwd(Dout) +// Pooling dimension = D, H, W +template +struct DeviceAvgPool3dBwd_NDHWC_NDHWC : public DeviceAvgPoolBwd<3, + DOutDataType, + DInDataType, + tensor_layout::convolution::NDHWC, + tensor_layout::convolution::NDHWC> +{ + static constexpr ck::index_t NDimSpatial = 3; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto + Make3DGridDescriptor_Out_M_K_In_M(const std::vector& dout_n_c_wos_lengths, + const std::vector& din_n_c_wos_length, + const std::vector& dout_n_c_wos_strides, + const std::vector& din_n_c_wos_strides, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads, + const std::vector& tildes) + { + index_t i_ztilde = tildes[0]; + index_t i_ytilde = tildes[1]; + index_t i_xtilde = tildes[2]; + + const index_t N = dout_n_c_wos_lengths[0]; + const index_t C = dout_n_c_wos_lengths[1]; + + const index_t Di = din_n_c_wos_length[2]; + const index_t Hi = din_n_c_wos_length[3]; + const index_t Wi = din_n_c_wos_length[4]; + + const index_t Do = dout_n_c_wos_lengths[2]; + const index_t Ho = dout_n_c_wos_lengths[3]; + const index_t Wo = dout_n_c_wos_lengths[4]; + + const index_t Z = window_lengths[0]; + const index_t Y = window_lengths[1]; + const index_t X = window_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = window_strides[0]; + const index_t ConvStrideH = window_strides[1]; + const index_t ConvStrideW = window_strides[2]; + + const index_t ConvDilationD = window_dilations[0]; + const index_t ConvDilationH = window_dilations[1]; + const index_t ConvDilationW = window_dilations[2]; + + const auto out_n_do_ho_wo_c_grid_desc = + make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, C), + make_tuple(dout_n_c_wos_strides[0], + dout_n_c_wos_strides[2], + dout_n_c_wos_strides[3], + dout_n_c_wos_strides[4], + dout_n_c_wos_strides[1])); + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on Tildes that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = + math::min(DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // ReduceK is different for each Reduce + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // Problem size of reduction kernel + const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + // Out[ReduceM, ReduceK] + const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{})); + + const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)), + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor( + out_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // In[ReduceM] + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C), + make_tuple(din_n_c_wos_strides[0], + din_n_c_wos_strides[2], + din_n_c_wos_strides[3], + din_n_c_wos_strides[4], + din_n_c_wos_strides[1])); + + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(XTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_grid_desc_reducemraw = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto in_grid_desc_reducem = + transform_tensor_descriptor(in_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem); + } + + using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0})); + + using DoutGridDesc_M_K = remove_cvref_t>; + using DinGridDesc_M = remove_cvref_t>; + + // FIXME + // for NDHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. + static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K + + using PassThrough = tensor_operation::element_wise::PassThrough; + using Div = tensor_operation::element_wise::UnaryDivide; + + using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + DInDataType* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : p_dout_grid_{p_dout}, + p_din_grid_{p_din}, + dout_n_c_wos_lengths_{dout_n_c_wos_lengths}, + din_n_c_wos_length_{din_n_c_wos_length}, + dout_n_c_wos_strides_{dout_n_c_wos_strides}, + din_n_c_wos_strides_{din_n_c_wos_strides}, + num_reduce_{1}, + div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]} + { + std::vector Tildes(NDimSpatial); + for(int i = 0; i < NDimSpatial; ++i) + { + int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]); + Tildes[i] = window_strides[i] / GcdStrideDilation; + num_reduce_ *= Tildes[i]; + } + + for(index_t i_ztilde = 0; i_ztilde < Tildes[0]; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < Tildes[1]; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < Tildes[2]; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + math::integer_divide_ceil(window_lengths[0] - i_ztilde, Tildes[0]); + const auto YDotSlice = + math::integer_divide_ceil(window_lengths[1] - i_ytilde, Tildes[1]); + const auto XDotSlice = + math::integer_divide_ceil(window_lengths[2] - i_xtilde, Tildes[2]); + + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto dout_din_grid_desc = + Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {i_ztilde, i_ytilde, i_xtilde}); + + dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]); + din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]); + } + } + } + } + + const DOutDataType* p_dout_grid_; + DInDataType* p_din_grid_; + std::vector dout_n_c_wos_lengths_; + std::vector din_n_c_wos_length_; + std::vector dout_n_c_wos_strides_; + std::vector din_n_c_wos_strides_; + + int num_reduce_; + std::vector dout_grid_desc_m_k_container_; + std::vector din_grid_desc_m_container_; + + Div div_element_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + for(index_t i = 0; i < arg.num_reduce_; i++) + { + const auto kernel = kernel_reduce_threadwise; + + ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0); + const index_t grid_size = (M / M_BlockTileSize); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.dout_grid_desc_m_k_container_[i], + arg.din_grid_desc_m_container_[i], + PassThrough{}, + arg.div_element_op_, + float(1), + arg.p_dout_grid_, + nullptr, + float(0), + arg.p_din_grid_, + nullptr); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + constexpr index_t Rank = NDimSpatial + 2; + int doutFastestDim = -1; + int dinFastestDim = -1; + + for(int i = 0; i < Rank; ++i) + { + if(arg.dout_n_c_wos_strides_[i] == 1) + doutFastestDim = i; + if(arg.din_n_c_wos_strides_[i] == 1) + dinFastestDim = i; + } + + if(doutFastestDim == -1 || dinFastestDim == -1) + { + if constexpr(InSrcOutDstVectorSize != 1) + return false; + } + else + { + if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0) + return false; + if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0) + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) override + { + constexpr index_t Rank = NDimSpatial + 2; + + if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank || + dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial || + window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial || + input_right_pads.size() != NDimSpatial) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(static_cast(p_dout), + static_cast(p_din), + dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceAvgPool3dBwd<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index 493822aeb27215406a8f7fad56021aca8e4246a7..4d599e8017991b4f1b1905de9acd23a219ac67f9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 9bf8f5ccd9ccecf14e1e0fdd005c331124f32547..32c45bc57e0fa6735a24f83e7ecfb62755933690 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,7 +57,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = @@ -543,9 +543,13 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle EGridDesc_G_M_N e_grid_desc_g_m_n_; }; + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -588,14 +592,18 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle LoopSched>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -840,8 +848,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index 20184458699811ce484863f4327539b80c64872e..ba22cf0bf868bf353cc4c550bc9c17ad5c2b1d03 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -75,7 +75,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -331,8 +331,13 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute, // DsDataType, @@ -378,13 +383,16 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; // Argument @@ -571,6 +579,11 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index 20e9920d935c12f54dd7ea89af3279b829892fe7..3dbe8c67226ec4e4f0dc5bf4ac02ac49ae73cf46 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -61,7 +61,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -589,8 +589,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -580,8 +588,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + */ + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ + defined(__gfx1101__) || defined(__gfx1102__)) + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + ds_grid_desc_m0_m10_m11_n0_n10_n11, + e_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_k0_m0_m1_k1; + ignore = b_grid_desc_k0_n0_n1_k1; + ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = e_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; + +#endif +} + +template && + is_same_v, + bool> = false> +struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD + +{ + using DeviceOp = DeviceBatchedGemmMultipleD_Dl; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = g_idx * static_cast(BatchStrideDs_[i]); + }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + std::array BatchStrideDs_; + index_t BatchStrideE_; + }; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlMultipleD_km_kn_mn; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using DsGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})); + using EGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + K_(K), + Batch_(Batch), + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + e_grid_desc_m0_m10_m11_n0_n10_n11_{}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE}, + block_2_ctile_map_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceBatchedGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceBatchedGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideDs[i]); + }); + e_grid_desc_m_n_ = + DeviceBatchedGemmMultipleD_Dl::MakeEGridDescriptor_M_N(M, N, StrideE); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + + ds_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_); + + e_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + index_t K_; + + // Batch + index_t Batch_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_; + EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_; + + // for calculating batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmMultipleD_Dl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_.GetLength(I0), + arg.e_grid_desc_m_n_.GetLength(I1)) * + arg.Batch_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop.value; + + const auto kernel = + kernel_gemm_dl_multiple_d; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.e_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // TODO: Enable for gfx90a after complier fix + if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx1030" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" || + ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102") + { + bool pass = true; + pass = pass && arg.K_ % K1 == 0; + + pass = pass && GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.e_grid_desc_m_n_); + + return pass; + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmMultipleD_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 196dc86da15411e95a17ea05eb484f716561c383..9455ec48b33419d49837a425eca08188216bb729 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -69,7 +69,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -196,6 +196,8 @@ template ; - using A0GridDesc_AK0_M_AK1 = remove_cvref_t; - using B0GridDesc_BK0_N_BK1 = remove_cvref_t; - using B1GridDesc_BK0_N_BK1 = remove_cvref_t; + using A0GridDesc_AK0_M_AK1 = + remove_cvref_t; + using B0GridDesc_BK0_N_BK1 = + remove_cvref_t; + using B1GridDesc_BK0_N_BK1 = + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -805,8 +812,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index ef9b90ba7cde65bf17336325be3c76e33c758e68..c4567108c9b2904f9a2b7e20edd32f4e4bd0d97b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -60,7 +60,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -801,6 +801,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 0c6c0ef7ad2ded496c0f03618a64dac8a2783b81..c38b7e1c5bc322a1765fdff02c0f26db205bb5e2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -68,7 +68,7 @@ __global__ void const C0MatrixMask c0_matrix_mask) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -197,7 +197,8 @@ template + int D0sTransferSrcScalarPerVector = 4, + LoopScheduler LoopSched = LoopScheduler::Default> struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle : public DeviceBatchedGemmSoftmaxGemmPermute; + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + D0sTransferSrcScalarPerVector>; // Argument // FIXME: constness @@ -530,6 +532,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle using D0DataType = remove_cvref_t>; // D0 pointer p_d0s_grid_(i) = static_cast(p_acc0_biases[i]); + // for check + d0s_nl_ns_lengths_strides_[i].push_back( + acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]); + d0s_nl_ns_lengths_strides_[i].push_back( + acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]); }); if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, @@ -608,6 +615,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle std::vector b_nz_kz_strides_; std::vector b1_nz_kz_strides_; std::vector c_mz_gemm1nz_strides_; + std::array, NumD0Tensor> d0s_nl_ns_lengths_strides_; index_t batch_count_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; @@ -715,8 +723,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle arg.Print(); #endif - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } @@ -772,6 +779,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle { return false; } + for(int i = 0; i < NumD0Tensor; i++) + { + if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 && + arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0) + { + return false; + } + if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1) + { + return false; + } + } return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 84edde63eceb141198ed249879db48131ff14d5c..f12e05fd349e41b0b06bf48bea091751d6325d63 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -63,7 +63,7 @@ __global__ void const C0MatrixMask c0_matrix_mask) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -613,8 +613,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index d35f194171e710b490a70906ed2cf2079ab7b77a..303eba156e328e16ce3dec4a24f88021432bf477 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,75 +45,46 @@ namespace device { * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). * */ -template +template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_v2r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2CTileMap block_2_ctile_map) + kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + static_cast(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + static_cast(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + static_cast(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + c_batch_offset, p_shared, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + c_grid_desc_m_n); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = batch_count; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; - ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = compute_ptr_offset_of_batch; - ignore = block_2_ctile_map; + ignore = karg; #endif } @@ -171,93 +142,6 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - - const auto a_grid_desc_k0_mp_k1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_k0_mp_k1; - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - const auto b_grid_desc_k0_np_k1 = - transform_tensor_descriptor(b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_k0_np_k1; - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - const auto c_grid_desc_mp_np = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return c_grid_desc_mp_np; - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - struct ComputePtrOffsetOfStridedBatch { ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, @@ -289,121 +173,82 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - NumGemmKPrefetchStage, - LoopSched, - PipelineVer>; - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpecialization::MNKPadding, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using Problem = typename GridwiseGemm::Problem; // Argument - struct Argument : public BaseArgument + struct Argument : public Problem, public BaseArgument { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, - index_t Batch, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - Batch_(Batch), - a_grid_desc_k0_m_k1_{ - DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, - b_grid_desc_k0_n_k1_{ - DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, - c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC}, - block_2_ctile_map_{ - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - kraw_{K} + index_t Batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + Batch(Batch_), + compute_ptr_offset_of_batch{BatchStrideA, BatchStrideB, BatchStrideC} { - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - } } - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - index_t Batch_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; - Block2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t kraw_; + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; }; // Invoker @@ -411,107 +256,39 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm 0) { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + karg.Print(); } -#endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( - "wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + gdx *= karg.Batch; float ave_time = 0; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) { - const auto kernel = kernel_batched_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.Batch_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.block_2_ctile_map_); + const auto kernel = + kernel_batched_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } else { - const auto kernel = kernel_batched_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.Batch_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.block_2_ctile_map_); + const auto kernel = + kernel_batched_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } return ave_time; @@ -531,17 +308,14 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm(static_cast(p_a), static_cast(p_b), @@ -619,12 +385,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm" << " NumGemmKPrefetchStage: " << NumGemmKPrefetchStage << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp index ab16a757f699bb4e4ba781873bbf87538633d9c1..f46237e005642d433c29f5ab0f027a224043dbef 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp index 5a16ff765bbc5c0bc1b68d80262c88a0c8f84416..ad8e7956033db31db2992d6fa9f86cf16a29b150 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,12 +10,14 @@ #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp" #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" -#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" namespace ck { namespace tensor_operation { @@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwdinvariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64; + + // workspace for barrier objects, each barrier object consists of two integers + // TODO: allocate barrier object memory globally to reuse it by other operators + workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * + sizeof(int) * 2; } return (workspace_size); @@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwdblkGroupSize_ > 1) { - // setup buffer used for intermediate welford mean pArg_->workspace_mean_ = static_cast(pArg_->p_workspace_); @@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwdworkspace_count_ = reinterpret_cast(pArg_->workspace_variance_) + variance_space_sz; + + index_t count_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t); + + count_space_sz = math::integer_least_multiple(count_space_sz, 64); + + pArg_->control_ = reinterpret_cast(pArg_->workspace_count_) + count_space_sz; + + index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) / + M_BlockTileSize * sizeof(int) * 2; + + hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz)); }; }; @@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd; + using GridwiseMultiblockWelfordFirstHalf_ = GridwiseMultiblockWelfordFirstHalf; - index_t numMeanVarCountBlockTileIteration = - (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize; - - const auto kern_multiblock_welford_first_half = - kernel_multiblock_welford_first_half; - - const auto kern_welford_second_half_batchnorm_forward_final = - kernel_welford_second_half_batchnorm_forward_final< - GridwiseWelfordSecondHalfBatchNormForwardFinal_, - XDataType, - YDataType, - AccDataType, - ScaleDataType, - BiasDataType, - MeanVarDataType, - YElementwiseOp, - XYGridDesc_M_K, - MeanVarCountGridDesc_M_K, - ScaleBiasMeanVarGridDesc_M, - ScaleBiasMeanVarGridDesc_M>; - - avg_time += - launch_and_time_kernel(stream_config, - kern_multiblock_welford_first_half, - dim3(arg.gridSize_), - dim3(BlockSize), - 0, - arg.x_grid_desc_m_k_, - mean_var_count_grid_desc_m_g, - get_reduce_count_per_thread, - arg.numBlockTileIteration_, - arg.p_x_, - static_cast(arg.workspace_mean_), - static_cast(arg.workspace_variance_), - static_cast(arg.workspace_count_)); - - avg_time += - launch_and_time_kernel(stream_config, - kern_welford_second_half_batchnorm_forward_final, - dim3(arg.gridSize_), - dim3(BlockSize), - 0, - arg.x_grid_desc_m_k_, - arg.y_grid_desc_m_k_, - mean_var_count_grid_desc_m_k, - arg.scale_grid_desc_m_, - arg.bias_grid_desc_m_, - arg.mean_var_grid_desc_m_, - arg.blkGroupSize_, - arg.numBlockTileIteration_, - numMeanVarCountBlockTileIteration, - arg.epsilon_, - static_cast(arg.workspace_mean_), - static_cast(arg.workspace_variance_), - static_cast(arg.workspace_count_), - arg.p_x_, - arg.p_scale_, - arg.p_bias_, - arg.y_elementwise_op_, - arg.p_y_, - arg.updateMovingAverage_, - arg.averageFactor_, - arg.resultRunningMean_, - arg.resultRunningVariance_, - arg.saveMeanInvVariance_, - arg.resultSaveMean_, - arg.resultSaveInvVariance_); + // It is found that: + // 1) gfx1030 does not support the GLC enabled vector load/store, so using the + // two-kernel method for gfx1030 + // 2) Profiler on gfx908 could hang even though it works when running examples + // 3) Single-kernel method works on gfx1100, but the performance it not better + // than two-kernel method (due to more warps participating the barrier) + if(ck::get_device_name() == "gfx90a") + { + const auto kern_multiblock_batchnorm_fwd_ = + kernel_multiblock_batchnorm_forward; + + avg_time += launch_and_time_kernel( + stream_config, + kern_multiblock_batchnorm_fwd_, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, // for writing to mean/variance/count + // workspace by multiple workgroups + mean_var_count_grid_desc_m_k, // for reading from mean/variance/count + // workspace by each workgroup + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + static_cast(arg.control_), + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, // true or false + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, // true or false + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + } + else + { + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + const auto kern_welford_second_half_batchnorm_forward_final = + kernel_welford_second_half_batchnorm_forward_final< + GridwiseWelfordSecondHalfBatchNormForwardFinal_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + MeanVarCountGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M>; + + avg_time += launch_and_time_kernel( + stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_)); + + avg_time += launch_and_time_kernel( + stream_config, + kern_welford_second_half_batchnorm_forward_final, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_k, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + arg.blkGroupSize_, + arg.numBlockTileIteration_, + arg.epsilon_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + }; } else { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b826793c27be946a99d0d96a99d745fd6c618822 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp @@ -0,0 +1,714 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeXY2dDescriptor(const std::array& xyLengths, + const std::array& xyStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleXYLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + const auto tupleXYStrides = + generate_tuple([&](auto I) { return xyStrides[I]; }, Number{}); + + const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; }, + Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + + return transform_tensor_descriptor(raw_grid_desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = workSizePerBlock * blkGroupSize - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto grid_desc_m_g = make_naive_tensor_descriptor( + make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_g_padded = + transform_tensor_descriptor(grid_desc_m_g, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_pass_through_transform(blkGroupSize)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_g_padded); + }; + + static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto reduceLength = blkGroupSize; + const auto grid_desc_m_k = make_naive_tensor_descriptor( + make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = + math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto + MakeScaleBiasMeanVar1dDescriptor(const std::array& lengths, + const std::array& strides) + { + const auto tupleLengths = + generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + const auto tupleStrides = + generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + auto grid_desc_m = transform_tensor_descriptor( + raw_grid_desc, + make_tuple(make_merge_transform(tupleLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = + transform_tensor_descriptor(grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, mPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (grid_desc_m_padded); + }; + + using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)); + using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})); + + struct Argument : public BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* p_scale, + const BiasDataType* p_bias, + const YElementwiseOp y_elementwise_op, + double epsilon, + YDataType* p_y, + MeanVarDataType* resultSaveMean, + MeanVarDataType* resultSaveInvVariance, + double averageFactor, + MeanVarDataType* resultRunningMean, + MeanVarDataType* resultRunningVariance) + : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnBiasStrides_(bnBiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_scale_(p_scale), + p_bias_(p_bias), + y_elementwise_op_(y_elementwise_op), + p_y_(p_y), + resultSaveMean_(resultSaveMean), + resultSaveInvVariance_(resultSaveInvVariance), + resultRunningMean_(resultRunningMean), + resultRunningVariance_(resultRunningVariance) + { + xyLengths_ = + shuffle_tensor_dimensions(xyLengths, reduceDims); + xStrides_ = + shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = + shuffle_tensor_dimensions(yStrides, reduceDims); + + std::tie(invariant_length_, reduce_length_) = + get_2d_lengths(xyLengths_); + + epsilon_ = type_convert(epsilon); + averageFactor_ = type_convert(averageFactor); + + updateMovingAverage_ = + (resultRunningMean != nullptr && resultRunningVariance != nullptr); + saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr); + + if(UseMultiblockInK) + { + int iterations = 1; + while(true) + { + int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 16 + if(testBlkGroupSize <= 16) + break; + + iterations++; + }; + + blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration_ = iterations; + } + else + { + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_; + + x_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + scale_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_); + bias_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_); + mean_var_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_); + } + + AccDataType epsilon_; + AccDataType averageFactor_; + + bool updateMovingAverage_; + bool saveMeanInvVariance_; + + std::array xyLengths_; + std::array xStrides_; + std::array yStrides_; + + std::array bnScaleBiasMeanVarLengths_; + std::array bnScaleStrides_; + std::array bnBiasStrides_; + std::array bnMeanVarStrides_; + + const XDataType* p_x_; + const ScaleDataType* p_scale_; + const BiasDataType* p_bias_; + const YElementwiseOp y_elementwise_op_; + YDataType* p_y_; + + MeanVarDataType* resultSaveMean_; + MeanVarDataType* resultSaveInvVariance_; + + MeanVarDataType* resultRunningMean_; + MeanVarDataType* resultRunningVariance_; + + long_index_t invariant_length_; + long_index_t reduce_length_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t gridSize_; + + XYGridDesc_M_K x_grid_desc_m_k_; + XYGridDesc_M_K y_grid_desc_m_k_; + ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_; + + void* workspace_mean_; + void* workspace_variance_; + void* workspace_count_; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + // workspace for welford intermediate mean + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64; + } + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + + // setup buffer used for intermediate welford mean + pArg_->workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->workspace_variance_ = + reinterpret_cast(pArg_->workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welfor count + pArg_->workspace_count_ = + reinterpret_cast(pArg_->workspace_variance_) + variance_space_sz; + }; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + if(UseMultiblockInK && arg.blkGroupSize_ > 1) + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForMultiblockWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_); + + const auto mean_var_count_grid_desc_m_g = + DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + const auto mean_var_count_grid_desc_m_k = + DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); + using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); + + using GridwiseMultiblockWelfordFirstHalf_ = + GridwiseMultiblockWelfordFirstHalf; + + using GridwiseWelfordSecondHalfBatchNormForwardFinal_ = + GridwiseWelfordSecondHalfBatchNormForwardFinal; + + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + const auto kern_welford_second_half_batchnorm_forward_final = + kernel_welford_second_half_batchnorm_forward_final< + GridwiseWelfordSecondHalfBatchNormForwardFinal_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + MeanVarCountGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M>; + + avg_time += + launch_and_time_kernel(stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_)); + + avg_time += + launch_and_time_kernel(stream_config, + kern_welford_second_half_batchnorm_forward_final, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_k, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + arg.blkGroupSize_, + arg.numBlockTileIteration_, + arg.epsilon_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + } + else + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForBlockwiseWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.numBlockTileIteration_, arg.reduce_length_); + + using GridwiseBatchNormForwardWithBlockwiseWelford_ = + GridwiseBatchNormForwardWithBlockwiseWelford; + + const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford< + GridwiseBatchNormForwardWithBlockwiseWelford_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M, + GetReduceCountPerThreadFunctor>; + + avg_time += launch_and_time_kernel(stream_config, + kern_batchnorm_fwd, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, // true or false + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, // true or false + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + }; + + return (avg_time); + }; + + float Run(const BaseArgument* pArg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(pArg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* pArg) override + { + const Argument* pArg_ = dynamic_cast(pArg); + + if constexpr(XSrcYDstVectorDim == 0) + { + if(pArg_->xStrides_[NumInvariantDim - 1] != 1 || + pArg_->yStrides_[NumInvariantDim - 1] != 1) + return false; + + if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0) + return false; + } + else + { + if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1) + return false; + + if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1) + return false; + if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0) + return false; + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0) + return false; + + if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0) + return false; + + bool is_valid = true; + + static_for<0, NumInvariantDim, 1>{}([&](auto I) { + if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I]) + is_valid = false; + }); + + if(!is_valid) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_scale, + const void* p_bias, + double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, + double averageFactor, + void* resultRunningMean, + void* resultRunningVariance) override + { + return std::make_unique(xyLengths, + xStrides, + yStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_scale), + static_cast(p_bias), + y_elementwise_op, + epsilon, + static_cast(p_y), + static_cast(resultSaveMean), + static_cast(resultSaveInvVariance), + averageFactor, + static_cast(resultRunningMean), + static_cast(resultRunningVariance)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchNormFwdImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp index 0c9cce97fb7c32885d0ea761eb3d6f970a9a0327..a095521161c76d7bc831c76987a07691c415895c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ALayout, BLayout, CLayout, - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, GemmAccDataType, CShuffleDataType, CDataType, @@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v1; + const auto kernel = kernel_gemm_xdl_cshuffle_v1; ave_time += launch_and_time_kernel(stream_config, kernel, @@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle } else { - const auto kernel = - kernel_gemm_xdl_cshuffle_v1; + const auto kernel = kernel_gemm_xdl_cshuffle_v1; ave_time += launch_and_time_kernel(stream_config, kernel, @@ -448,6 +455,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + return GridwiseGemm::CheckValidity(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 1eaffe705a74c9cdffc9a29f912f134c2bffbd97..5ae06836fab74e27e70ea5f04c2ea1f66f8540b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -53,7 +53,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -310,9 +310,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -355,14 +359,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle LoopSched>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -582,8 +590,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index b65afce8df5834d14f4ee9b3b19730e11db43bee..23440e24f6828c55b0160c3cb06ee667da9d3f60 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -532,11 +532,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ float ave_time = 0; const auto Run = [&](const auto& kernel) { - hipGetErrorString(hipMemset( + hipGetErrorString(hipMemsetAsync( arg.p_c_grid_, 0, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * - sizeof(CDataType))); + sizeof(CDataType), + stream_config.stream_id_)); ave_time = launch_and_time_kernel(stream_config, @@ -649,6 +650,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + // vector load A/B matrix from global memory if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index ea3020663a88c1ce565beab17313e72928f50461..e22c5a2aa514d4ace3b05074be06a539a7b54433 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -379,9 +379,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, @@ -428,20 +425,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) : p_a_grid_{p_out_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_in_grid}, - M01_{M01}, - N01_{N01}, - a_element_op_{out_element_op}, - b_element_op_{wei_element_op}, - c_element_op_{in_element_op}, Conv_N_{N}, Conv_K_{K}, Conv_C_{C}, @@ -495,18 +482,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01); - - if(GridwiseGemm::CheckValidity( - descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } } @@ -517,14 +492,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector a_grid_desc_k0_m_k1_container_; std::vector b_grid_desc_k0_n_k1_container_; std::vector c_grid_desc_m_n_container_; - std::vector - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; - std::vector block_2_ctile_map_container_; - index_t M01_; - index_t N01_; - OutElementwiseOperation a_element_op_; - WeiElementwiseOperation b_element_op_; - InElementwiseOperation c_element_op_; // for checking IsSupportedArgument() index_t Conv_N_; index_t Conv_K_; @@ -567,103 +534,68 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) - << " ) " << std::endl; } #endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); } - const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( - arg.c_grid_desc_m_n_container_[i]); + const auto [gdx, gdy, gdz] = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - true>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - false>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } } return ave_time; @@ -684,6 +616,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) { @@ -716,8 +653,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { return false; } @@ -742,10 +678,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -759,12 +692,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; + input_right_pads}; } static auto MakeInvoker() { return Invoker{}; } @@ -783,9 +711,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), @@ -799,12 +727,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); + input_right_pads); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index e7e2bf33540fd8dd05ba1062d32241d39ff371cd..c9e8940edc4030be672d7c9def1225c9f650eee5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -810,6 +810,11 @@ struct static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 6c4957b9b17f68c3a5d33470582d1d67907a8e85..28fceb428eff34628cc2f5c680bd9a45c0b856e6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -767,6 +767,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 027f1a1954da565721a1bf072f58e8feddcf29ca..ca291d3b11f5273caaf7900bf349eddceb108658 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -741,6 +741,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 6278220c2f9dee047f632acef0e0d2f911156fa6..ef94120f4e89f199018fc62c5af481fc0aca4f83 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -329,9 +329,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, @@ -378,25 +375,13 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) : p_a_grid_{p_in_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_out_grid}, a_grid_desc_k0_m_k1_{}, b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - in_element_op_{in_element_op}, - wei_element_op_{wei_element_op}, - out_element_op_{out_element_op}, Conv_N_{N}, Conv_K_{K}, Conv_C_{C}, @@ -420,17 +405,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - } } // private: @@ -440,14 +414,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - InElementwiseOperation in_element_op_; - WeiElementwiseOperation wei_element_op_; - OutElementwiseOperation out_element_op_; // for checking IsSupportedArgument() index_t Conv_N_; index_t Conv_K_; @@ -479,17 +445,14 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } #endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -498,22 +461,18 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r3; ave_time = launch_and_time_kernel(stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_a_grid_, @@ -521,30 +480,22 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); + arg.c_grid_desc_m_n_); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r3; ave_time = launch_and_time_kernel(stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_a_grid_, @@ -552,11 +503,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); + arg.c_grid_desc_m_n_); } return ave_time; @@ -577,6 +524,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -616,10 +568,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); } bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -639,10 +589,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -656,12 +603,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; + input_right_pads}; } static auto MakeInvoker() { return Invoker{}; } @@ -680,9 +622,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), @@ -696,12 +638,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); + input_right_pads); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp index f69d8f18ae036b8e7cd060a80a0cca12bfc2a050..cd89f3232cb94e9ce9a68c06587758ac9574816f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef DEVICE_CONV3D_FWD_NAIVE_HPP #define DEVICE_CONV3D_FWD_NAIVE_HPP diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index d52879cd904f05d165557c08ae6d123fbfa18434..a8e586b20c38a40438636c5f9cafafb5c34409a9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef DEVICE_CONV3D_FWD_XDL_HPP #define DEVICE_CONV3D_FWD_XDL_HPP @@ -56,7 +56,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -524,6 +524,11 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index aff25aa7c66758a0636227e0494cf0f54359dc44..3178f73f4b530e67b399cd0ae71064a4b66fa837 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 1fd4b76cec4cad5fd23b6a64166d210abdcb902a..ee3f0cea1b37f1de5ce3f53d6252e62024c8ddf8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -980,9 +980,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, @@ -1029,20 +1026,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) : p_a_grid_{p_out_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_in_grid}, - M01_{M01}, - N01_{N01}, - a_element_op_{out_element_op}, - b_element_op_{wei_element_op}, - c_element_op_{in_element_op}, Conv_N_{N}, Conv_K_{K}, Conv_C_{C}, @@ -1092,17 +1079,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); - - if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } template ::type = false> @@ -1150,18 +1126,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); - - if(GridwiseGemm::CheckValidity( - descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } } @@ -1218,19 +1182,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); - - if(GridwiseGemm::CheckValidity( - descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( - descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } } @@ -1242,11 +1193,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector a_grid_desc_k0_m_k1_container_; std::vector b_grid_desc_k0_n_k1_container_; std::vector c_grid_desc_m_n_container_; - std::vector - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; - std::vector block_2_ctile_map_container_; - index_t M01_; - index_t N01_; OutElementwiseOperation a_element_op_; WeiElementwiseOperation b_element_op_; InElementwiseOperation c_element_op_; @@ -1276,123 +1222,84 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl { #if DEBUG_LOG { - std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + std::cout << "arg.a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" << std::endl; - std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + std::cout << "arg.b_grid_desc_k0_n_k1{" << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" << std::endl; - std::cout << "arg.c_grid_desc_m_n_container_{ " + std::cout << "arg.c_grid_desc_m_n{" << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I6) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I7) - << " ) " << std::endl; } #endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( - arg.c_grid_desc_m_n_container_[i]); + const auto [gdx, gdy, gdz] = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - true>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - false>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } } return ave_time; @@ -1413,6 +1320,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) { @@ -1446,8 +1358,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { return false; } @@ -1472,10 +1383,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -1489,12 +1397,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; + input_right_pads}; } static auto MakeInvoker() { return Invoker{}; } @@ -1513,9 +1416,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), @@ -1529,12 +1432,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); + input_right_pads); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp index 83ed6198bd3c0921988d0d867a87a617bc2c76a1..02ef29e32ddc2bbd985fb4f46c6eb3825c00c11a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" namespace ck { namespace tensor_operation { @@ -171,10 +172,7 @@ struct DeviceElementwise2dImpl : public DeviceElementwise 0, ""); static_assert(NumDim_n > 0, ""); @@ -192,34 +190,10 @@ struct DeviceElementwise2dImpl : public DeviceElementwise(out_dev_buffers[I.value]); }, Number{}); - - in_grid_2d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_MN(lengths, - inStridesArray[I.value], - gridSize_, - blockSize_, - num_threads_m_, - num_threads_n_); - }, - Number{}); - - out_grid_2d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_MN(lengths, - outStridesArray[I.value], - gridSize_, - blockSize_, - num_threads_m_, - num_threads_n_); - }, - Number{}); } InDataTypePointerTuple in_dev_buffers_; OutDataTypePointerTuple out_dev_buffers_; - InGrid2dDescTuple in_grid_2d_desc_tuple_; - OutGrid2dDescTuple out_grid_2d_desc_tuple_; std::array lengths_; std::array, NumInput> inStridesArray_; @@ -227,15 +201,38 @@ struct DeviceElementwise2dImpl : public DeviceElementwise{}); + + auto out_grid_2d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_MN(arg.lengths_, + arg.outStridesArray_[I.value], + gridSize, + arg.blockSize_, + num_threads_m, + num_threads_n); + }, + Number{}); + const auto kernel = kernel_elementwise_2d(out_dev_buffers[I.value]); }, Number{}); - - in_grid_1d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_M( - lengths, inStridesArray[I.value], gridSize_, blockSize_); - }, - Number{}); - - out_grid_1d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_M( - lengths, outStridesArray[I.value], gridSize_, blockSize_); - }, - Number{}); } InDataTypePointerTuple in_dev_buffers_; OutDataTypePointerTuple out_dev_buffers_; - InGrid1dDescTuple in_grid_1d_desc_tuple_; - OutGrid1dDescTuple out_grid_1d_desc_tuple_; std::array lengths_; std::array, NumInput> inStridesArray_; @@ -187,13 +171,28 @@ struct DeviceElementwiseImpl ElementwiseOperation elementwise_op_; index_t blockSize_; - index_t gridSize_; }; struct Invoker : public BaseInvoker { float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + + auto in_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + + auto out_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + const auto kernel = kernel_elementwise_1d + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmDpp : public DeviceGemm +{ + using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< + BlockSize, + ADataType, + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + MPerBlock, + NPerBlock, + KPerBlock, + MPerDpp, + NPerDpp, + AK1, + BK1, + MDppPerWave, + NDppPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting"); + } + + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + else + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1100" || + ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102") + { + return GridwiseGemm::CheckValidity(karg); + } + return false; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmDpp" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerDpp << ", " + << NPerDpp << ", " + << MDppPerWave << ", " + << MDppPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index c91328ff713f7eb7a1b54ef3bf923c28911751ee..ad51096db7f6d8aa1cd8b4a8c7590ff385009943 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -52,7 +52,7 @@ __global__ void { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__)) + defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); @@ -555,7 +555,8 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD( @@ -364,11 +364,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; // We have to separate mean var descriptor for gemm and layernorm bacause of different grid // layout(different padding) - using GemmMeanVarGridDesc_M_NBlock = decltype( - MakeMeanVarDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, 1)); + using GemmMeanVarGridDesc_M_NBlock = + decltype(MakeMeanVarDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, + 1)); - using GemmCountGridDesc_M_NBlock = decltype( - MakeCountDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, 1)); + using GemmCountGridDesc_M_NBlock = + decltype(MakeCountDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, + 1)); using LayernormMeanVarGridDesc_M_NBlock = decltype(MakeMeanVarDescriptor_M_N, @@ -855,8 +857,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 4c1c3ab7bf1d10d5f8f75f88c41064db1fb51481..916f29a904eba4172b13651249263d085b16f378 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -61,7 +61,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -337,10 +337,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; @@ -555,8 +557,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index 750df31a3f6395ccf52fc27ac7705311b24e4376..44b3518e2c950e15ec3d5cfd240c443029987622 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 2488101484a6d3674be42afe79b836d61058a58a..c90c28f5a8324083c17b44738f4bbb6356e21492 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,7 +20,8 @@ namespace ck { template (p_a_grid, @@ -242,9 +243,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using ComputeDataType = EDataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -288,14 +293,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -438,6 +447,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, + ALayout, + BLayout, + CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + GemmSpec, MPerBlock, NPerBlock, K0PerBlock, @@ -232,173 +120,41 @@ struct DeviceGemmXdl : public DeviceGemm; - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - kraw_{K} - { - a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t kraw_; - }; + using Argument = typename GridwiseGemm::Argument; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmXdl::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(stream_config.log_level_ > 0) { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + karg.Print(); } -#endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( - "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); float ave_time = 0; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } return ave_time; @@ -418,7 +174,7 @@ struct DeviceGemmXdl : public DeviceGemm(static_cast(p_a), static_cast(p_b), @@ -511,12 +251,7 @@ struct DeviceGemmXdl : public DeviceGemm + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeType = CDataType> struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; + PipelineVer, + ComputeType>; using Argument = typename GridwiseGemm::Argument; @@ -188,8 +191,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm + index_t CBlockTransferScalarPerVector_NWaveNPerXDL, + typename ComputeType = CDataType, + PipelineVersion PipelineVer = PipelineVersion::v1> + struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + PipelineVer, + ComputeType>; using Argument = typename GridwiseGemm::Argument; using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; @@ -158,8 +162,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK 1) - hipGetErrorString( - hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType))); + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); ave_time = launch_and_time_kernel( stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map); @@ -231,6 +237,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlStreamK : public DeviceGemmStreamK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< + BlockSize, + BlockToCTileMap_GemmStreamK, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + void Print(const Argument& karg) { karg.Print(); } + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(karg); + } + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); + } + + dim3 grid_dims = karg.block_mapping.get_grid_dims(); + + float ave_time = 0; + + const auto kernel = kernel_gemm_xdlops_streamk; + + // TODO: remove clear buffer for streamk kernels + if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Atomic) + { + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); + ave_time = launch_and_time_kernel(stream_config, + kernel, + grid_dims, + dim3(BlockSize), + 0, + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + karg.p_workspace_, + karg.M, + karg.N, + karg.K, + karg.StrideA, + karg.StrideB, + karg.StrideC, + karg.block_mapping); + } + else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + char* workspace_semaphore = reinterpret_cast(karg.p_workspace_) + + karg.block_mapping.get_workspace_size_for_acc( + sizeof(typename GridwiseGemm::FloatAcc)); + auto preprocess = [&]() { + hipGetErrorString( + hipMemsetAsync(workspace_semaphore, + 0, + karg.block_mapping.get_workspace_size_for_semaphore(), + stream_config.stream_id_)); + }; + + ave_time = launch_and_time_kernel_with_preprocess(stream_config, + preprocess, + kernel, + grid_dims, + dim3(BlockSize), + 0, + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + karg.p_workspace_, + karg.M, + karg.N, + karg.K, + karg.StrideA, + karg.StrideB, + karg.StrideC, + karg.block_mapping); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* p_arg = dynamic_cast(pArg); + if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc)); + } + else + { + return 0; + } + } + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + } + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) + { + return false; + } + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + uint32_t NumSKBlocks = 0xffffffff) + { + const auto kernel = kernel_gemm_xdlops_streamk; + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + static_cast(num_cu), + static_cast(occupancy), + NumSKBlocks}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + index_t NumSKBlocks = 0) override + { + const auto kernel = kernel_gemm_xdlops_streamk; + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + return std::make_unique(reinterpret_cast(p_a), + reinterpret_cast(p_b), + reinterpret_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + static_cast(num_cu), + static_cast(occupancy), + static_cast(NumSKBlocks)); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp index af38f142549c6475a7fa983585f0a845286aa16e..1b34e2dba284c6996ccc85b86f424994a1591423 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -48,7 +48,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -248,10 +248,12 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; @@ -417,8 +419,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -400,14 +404,18 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle LoopSched>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; struct GroupedContractionBlock2ETileMap { @@ -652,11 +660,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle } } - hipGetErrorString(hipMemcpy(arg.p_workspace_, - arg.contraction_multi_d_kernel_args_.data(), - arg.contraction_multi_d_kernel_args_.size() * - sizeof(ContractionMultiDKernelArg), - hipMemcpyHostToDevice)); + hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_, + arg.contraction_multi_d_kernel_args_.data(), + arg.contraction_multi_d_kernel_args_.size() * + sizeof(ContractionMultiDKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); float ave_time = 0; @@ -704,8 +713,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 92c20a308f9bc3b8c5f98581ce3e627d30420ebb..c828269acff6e02c8fcc256e4ea5cfbb9a3aa5be 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -131,7 +131,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -258,7 +258,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CDEElementwiseOp> { // FIXME - static_assert(NDimSpatial == 2, "wrong! only implemented for 2D now"); + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; @@ -279,6 +280,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 BK1, MPerBlock, NPerBlock, + KPerBlock, DoPadGemmM, DoPadGemmN>{}; @@ -354,6 +356,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ABDataType, // TODO: distinguish A/B datatype + ABDataType, // TODO: distinguish A/B datatype ABDataType, // TODO: distinguish A/B datatype AccDataType, CShuffleDataType, @@ -421,10 +425,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{})); - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); // block-to-e-tile map using Block2ETileMap = @@ -459,7 +465,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_k_wos_lengths[0]}, - num_gemm_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -492,133 +497,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; }); + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + // problem definition - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; - const index_t ConvStrideH = conv_filter_strides_[0]; - const index_t ConvStrideW = conv_filter_strides_[1]; + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations_[0]; - const index_t ConvDilationW = conv_filter_dilations_[1]; + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto XTilde = ConvStrideW / GcdStrideDilationW; - // number of GEMM - num_gemm_ = YTilde * XTilde; - - for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) { - for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) - { - // check slice is valid - const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); - const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - if(YDotSlice * XDotSlice <= 0) + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) { - continue; - } - - const auto a_grid_desc_ak0_m_ak1 = - transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - {i_ytilde, i_xtilde}); - - const auto b_grid_desc_bk0_n_bk1 = - transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - {i_ytilde, i_xtilde}); - - DsGridDesc_M_N ds_grid_desc_m_n; - - // populate Ds desc - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - - ds_grid_desc_m_n(i) = - transform_conv_to_gemm.template MakeCDescriptor_M_N( + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( a_g_n_k_wos_lengths, a_g_n_k_wos_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, - ds_g_n_c_wis_lengths[i], - ds_g_n_c_wis_strides[i], + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads, - {i_ytilde, i_xtilde}); - }); - - const auto e_grid_desc_m_n = - transform_conv_to_gemm.template MakeCDescriptor_M_N( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - {i_ytilde, i_xtilde}); - - // desc for problem definition - const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); - const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); - - a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); - b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); - ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); - e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); - - // desc for blockwise copy - a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); - b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); - - // block-to-e-tile-map - auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); - - block_2_etile_map_container_.push_back(block_2_etile_map); - - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - block_2_etile_map)) - { - ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n)); + tildes); - e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n)); + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(i) = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + }); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + // desc for problem definition + const auto a_grid_desc_m_k = + transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = + transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + // desc for blockwise copy + a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); + b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); + + // block-to-e-tile-map + auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + block_2_etile_map_container_.push_back(block_2_etile_map); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)); + } } } } @@ -626,7 +670,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 void Print() const { - for(index_t i = 0; i < num_gemm_; i++) + for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++) { std::cout << "a_grid_desc_ak0_m_ak1_container_" << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; @@ -654,7 +698,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // tensor descriptor for problem definition index_t num_group_; - index_t num_gemm_; std::vector a_grid_desc_m_k_container_; std::vector b_grid_desc_n_k_container_; std::vector ds_grid_desc_m_n_container_; @@ -708,7 +751,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 float ave_time = 0; - for(index_t i = 0; i < arg.num_gemm_; i++) + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], arg.b_grid_desc_n_k_container_[i], @@ -788,6 +831,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; @@ -807,7 +855,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // vector load for A matrix from global memory to LDS - if constexpr(is_same_v) + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) { if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) { @@ -820,7 +871,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // vector load for B matrix from global memory to LDS - if constexpr(is_same_v) + if constexpr(is_same_v || + is_same_v) { if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) { @@ -839,7 +891,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using DLayout = remove_cvref_t>; if constexpr(is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v || is_same_v) @@ -862,7 +916,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // vector store for E - if constexpr(is_same_v) + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) { // vector store C matrix into global memory if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp index 0473eaf7625fe64b99736e45f96650ee076780f8..198751cdf350f7f9d828edda641b8ef217e98af6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp @@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; @@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl } // function end template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; @@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; @@ -784,17 +784,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, const OutDataType* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& /*a_g_n_c_wis_strides*/, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& /*b_g_k_c_xs_strides*/, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& /*e_g_n_k_wos_strides*/, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, @@ -810,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl a_element_op_{out_element_op}, b_element_op_{wei_element_op}, c_element_op_{in_element_op}, - Conv_G_{G}, - Conv_N_{N}, - Conv_K_{K}, - Conv_C_{C}, - input_spatial_lengths_{input_spatial_lengths}, - filter_spatial_lengths_{filter_spatial_lengths}, - output_spatial_lengths_{output_spatial_lengths}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, k_batch_{split_k} { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -854,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = - N * K * - std::accumulate(begin(output_spatial_lengths), - end(output_spatial_lengths), + Conv_N_ * Conv_K_ * + std::accumulate(begin(output_spatial_lengths_), + end(output_spatial_lengths_), index_t{1}, std::multiplies<>{}); compute_ptr_offset_of_batch_.BatchStrideB_ = - N * C * - std::accumulate(begin(input_spatial_lengths), - end(input_spatial_lengths), + Conv_N_ * Conv_C_ * + std::accumulate(begin(input_spatial_lengths_), + end(input_spatial_lengths_), index_t{1}, std::multiplies<>{}); compute_ptr_offset_of_batch_.BatchStrideC_ = - K * C * - std::accumulate(begin(filter_spatial_lengths), - end(filter_spatial_lengths), + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); } @@ -897,18 +907,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl InElementwiseOperation c_element_op_; // for checking IsSupportedArgument() - index_t Conv_G_; - index_t Conv_N_; - index_t Conv_K_; - index_t Conv_C_; + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; std::array input_spatial_lengths_; std::array filter_spatial_lengths_; std::array output_spatial_lengths_; - std::array conv_filter_strides_; - std::array conv_filter_dilations_; - std::array input_left_pads_; - std::array input_right_pads_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; index_t k_batch_; }; @@ -1108,35 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl return IsSupportedArgument(*dynamic_cast(p_arg)); } - static auto MakeArgument(const InDataType* p_in_grid, - WeiDataType* p_wei_grid, - const OutDataType* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - ck::index_t split_k) + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) { return Argument{p_in_grid, p_wei_grid, p_out_grid, - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1153,17 +1162,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl MakeArgumentPointer(const void* p_in_grid, void* p_wei_grid, const void* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, @@ -1172,13 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), static_cast(p_out_grid), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp similarity index 64% rename from include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index c921c9f1b400f76627376d4865498aa27e7b49c3..71c528e4c92e339b334d78faab82fe10c7bf914e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -79,7 +79,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -126,6 +126,9 @@ __global__ void // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] template -struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle - : public DeviceGroupedConvBwdWeight< - NDimSpatial, - ck::tuple_element_t>, - ck::tuple_element_t>, - ck::tuple_element_t>, - InDataType, - WeiDataType, - OutDataType, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle + : public DeviceGroupedConvBwdWeight { - using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle; + using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; using ADataType = OutDataType; using BDataType = InDataType; @@ -196,6 +189,30 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle // TODO make A/B datatype different using ABDataType = InDataType; + // 1d + static constexpr bool is_GNWK_GKXC_GNWC = + is_same_v && + is_same_v && + is_same_v; + // 2d + static constexpr bool is_NHWGK_GKYXC_NHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNHWK_GKYXC_GNHWC = + is_same_v && + is_same_v && + is_same_v; + // 3d + static constexpr bool is_NDHWGK_GKZYXC_NDHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNDHWK_GKZYXC_GNDHWC = + is_same_v && + is_same_v && + is_same_v; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -220,19 +237,132 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; static constexpr auto BBlockLdsN1Padding = 4; + template ::type = false> + constexpr static auto + make_out_grid_desc(const ck::index_t N, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const ck::index_t N, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(NStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const ck::index_t K, + const ck::index_t Y, + const ck::index_t X, + const ck::index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride)); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const ck::index_t N, + const ck::index_t Do, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const ck::index_t N, + const ck::index_t Di, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const ck::index_t K, + const ck::index_t Z, + const ck::index_t Y, + const ck::index_t X, + const ck::index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& /* input_strides */, + const std::array& /* weights_strides */, + const std::array& /* output_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; @@ -248,6 +378,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const index_t GemmM = K; const index_t GemmN = C * X; + const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; + const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; + const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * @@ -282,14 +415,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( in_gemmktotal_gemmn_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -366,25 +499,56 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); } } template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; @@ -413,21 +577,25 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const index_t GemmM = K; const index_t GemmN = C * X * Y; + const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; + const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; + const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); + if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, + out_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), make_pass_through_transform(GemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -441,41 +609,29 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, + in_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - // A: output tensor const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, + out_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), make_pass_through_transform(GemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -490,7 +646,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle // B: input tensor const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, + in_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW), @@ -529,29 +685,56 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); } } template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; @@ -587,21 +770,25 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const index_t GemmM = K; const index_t GemmN = C * Z * X * Y; + const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; + const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; + const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, + out_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), make_pass_through_transform(GemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -615,41 +802,29 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, + in_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - const auto in_n_di_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - // A: output tensor const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, + out_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), make_pass_through_transform(GemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -664,7 +839,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle // B: input tensor const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_grid_desc, + in_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Hi, InLeftPadH, InRightPadH), @@ -712,44 +887,110 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); } } // function end template ::type = false> static auto GetABCGridDesc() { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( - 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } template ::type = false> static auto GetABCGridDesc() { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( - 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } template ::type = false> static auto GetABCGridDesc() { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, - 1, - 1, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - 1); + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } // type convert descs @@ -863,19 +1104,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, const OutDataType* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t M01, - ck::index_t N01, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, @@ -894,25 +1134,40 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle a_element_op_{out_element_op}, b_element_op_{in_element_op}, c_element_op_{wei_element_op}, - Conv_G_{G}, - Conv_N_{N}, - Conv_K_{K}, - Conv_C_{C}, - output_spatial_lengths_{output_spatial_lengths}, - filter_spatial_lengths_{filter_spatial_lengths}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, k_batch_{split_k} { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + a_g_n_c_wis_strides, + b_g_k_c_xs_strides, + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -927,22 +1182,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = - N * K * - std::accumulate(begin(output_spatial_lengths), - end(output_spatial_lengths), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideB_ = - N * C * - std::accumulate(begin(input_spatial_lengths), - end(input_spatial_lengths), - index_t{1}, - std::multiplies<>{}); + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = - K * C * - std::accumulate(begin(filter_spatial_lengths), - end(filter_spatial_lengths), + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); @@ -977,16 +1222,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle WeiElementwiseOperation c_element_op_; // for checking IsSupportedArgument() - index_t Conv_G_; - index_t Conv_N_; - index_t Conv_K_; - index_t Conv_C_; - std::array output_spatial_lengths_; + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; std::array filter_spatial_lengths_; - std::array conv_filter_strides_; - std::array input_left_pads_; - std::array input_right_pads_; - index_t k_batch_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; }; // Invoker @@ -1091,6 +1337,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWK_GKXC_GNWC) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)) + { + return false; + } + } + else + { + return false; + } + if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { @@ -1131,35 +1403,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle return IsSupportedArgument(*dynamic_cast(p_arg)); } - static auto MakeArgument(const InDataType* p_in_grid, - WeiDataType* p_wei_grid, - const OutDataType* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - ck::index_t split_k) + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) { return Argument{p_in_grid, p_wei_grid, p_out_grid, - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1178,32 +1449,30 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle MakeArgumentPointer(const void* p_in_grid, void* p_wei_grid, const void* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, - ck::index_t split_k) override + const ck::index_t split_k) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), static_cast(p_out_grid), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1226,7 +1495,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle" + str << "DeviceGroupedConvBwdWeight_Xdl_CShuffle" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp similarity index 99% rename from include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 6180333350ad414d5bcdaac78e7892546757d605..8b22bd209043e900cd594bcac7923833e9a26536 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -136,7 +136,7 @@ __global__ void { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__)) + defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -381,8 +381,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK } // desc for problem definition - using AGridDesc_AK0_M_AK1 = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t({}, {}))>; using DsGridDesc_M_N = remove_cvref_t; @@ -713,7 +713,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" || - ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")) + ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" || + ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp similarity index 99% rename from include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 80c864c8366af105421d3aadf46f620f9cde17e9..f18fbcfe4b51031009f6329baac2bc1ac209e1fa 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -320,8 +320,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t({}, {}))>; using CGridDesc_M_N = remove_cvref_t({}, {}))>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp index 03185d5b1d29186f2b2814257eacfe90da6c5e02..face627e1fb85d0621823803eabc5409b0e73d61 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index de40d71293038b281d1b4c119140883e1c38ac7f..caa18b709cea604930c3c81c80cdbebc5d9cfaa1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -156,7 +156,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -446,8 +446,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); } - using AGridDesc_M_K = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_N_K = remove_cvref_t({}, {}))>; using EGridDesc_M_N = remove_cvref_t({}, {}))>; using RGridDesc_M = remove_cvref_t({}, {}))>; @@ -507,10 +507,12 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; @@ -811,7 +813,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return false; } } - else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940") + else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || + get_device_name() == "gfx941" || get_device_name() == "gfx942") { if constexpr(!(is_same_v || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 9d4b68c0b6429aacd16106d1088db5e1c1f74d4f..1d73a723871501897ef8a81a23a94b07a0d5faad 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -245,8 +245,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } // desc for problem definition - using AGridDesc_M_K = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_N_K = remove_cvref_t({}, {}))>; using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = remove_cvref_t({}, {}))>; @@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1, stride=1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -616,7 +616,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index 02458bf02a39870118b846bf99fc774d5d73ee4c..bcef5c3b621406aaa8d514ce371f02952ac4818f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -136,7 +136,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -361,15 +361,19 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle } // desc for problem definition - using AGridDesc_M_K = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_N_K = remove_cvref_t({}, {}))>; using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = remove_cvref_t({}, {}))>; + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -412,14 +416,18 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle LoopSched>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -685,7 +693,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle return false; } } - else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940") + else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || + get_device_name() == "gfx941" || get_device_name() == "gfx942") { if constexpr(!(is_same_v || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index e88bf8ed758cb685feb0b51ee62cdbf7c7189ad7..0190b3cee6b44af7656c34be866c9ab0855f7b0d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -41,7 +41,7 @@ __global__ void { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__) || defined(__gfx940__)) + defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -597,10 +597,12 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -272,14 +276,18 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; struct GroupedGemmBlock2ETileMap { @@ -548,11 +556,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(arg.gemm_desc_kernel_arg_.size()) + arg.skipped_group_count_) != arg.group_count_) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8cea09ee54799957e24212966758a444451d5bdd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -0,0 +1,836 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + uint32_t* barrier_count, + const index_t barrier_size_grp, + const index_t group_count, + const index_t grid_size_grp, + const index_t KBatch, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M * N * K == 0) + return; + + const auto StrideA = gemm_desc_ptr[group_id].StrideA; + const auto StrideB = gemm_desc_ptr[group_id].StrideB; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + const index_t mn_blocks = local_grid_size / KBatch; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + auto barrier_count_finished = + barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; + + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } +#else + ignore = gemm_descs_const; + ignore = barrier_count; + ignore = barrier_size_grp; + ignore = group_count; + ignore = grid_size_grp; + ignore = KBatch; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumPrefetch, // NumGemmKPrefetchStage + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + struct GemmBiasTransKernelArg + { + // pointers + const void* a_ptr_; + const void* b_ptr_; + std::array ds_ptr_; + void* e_ptr_; + + index_t M_, N_, K_; + index_t StrideA_, StrideB_; + std::array StrideDs_; + index_t StrideE_; + }; + + // Argument + struct Argument : public BaseArgument + { + + void UpdateKBatch(index_t k_batch) + { + k_batch_ = k_batch; + + if(k_batch_ < 1) + { + + throw std::runtime_error("wrong! k_batch must be > 0"); + } + + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + + const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_; + const index_t N = gemm_desc_kernel_arg_[0].N_; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + grid_size_ = grid_size_grp_ * group_count_; + } + + Argument(std::vector&, + std::vector&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideE = gemm_descs[i].stride_C_; + + // pointer + std::array p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideDs; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + // using DLayout = remove_cvref_t>; + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + StrideDs[j] = gemm_descs[i].stride_Ds_[j]; + }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + // check block-to-E-tile + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + if(!GridwiseGemm:: + template CheckValidity( + AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + nullptr, + nullptr, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + }); + + group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + const auto KPad = + GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + reinterpret_cast(arg.p_workspace_), + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; + + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by vector + // load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, + // thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl_Fixed_NK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * sizeof(GroupedGemmKernelArgument); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override + { + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; + + hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg))); + } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + return SetKBatch(*dynamic_cast(p_arg), k_batch); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 54ad9eb065a440e114f34c2842e5efe21267a98d..41445aeaf076408adc7159359884803f0965e313 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -35,7 +35,7 @@ __global__ void const index_t group_count) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -114,7 +114,8 @@ template > && is_same_v>, @@ -142,7 +143,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; + PipelineVer>; using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; using Block2ETileMapKSplit = @@ -406,10 +408,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..19f126e66fe4c1b6b662eccbe09296f798bb80d3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_image_to_column.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_image_to_column(const InputGridDesc in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const Block2ETileMap block_2_tile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ + defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) + GridwiseImageToColumnKernel::Run( + in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map); +#else + ignore = in_grid_desc; + ignore = p_in_global; + ignore = out_grid_desc; + ignore = p_out_global; + ignore = block_2_tile_map; +#endif +} + +// Image to column for input layout NDHWC: +// input : input image [N, Di, Hi, Wi, C], +// output : output image [N * Do * Ho * Wo, Z * Y * X * C] +template +struct DeviceImageToColumnImpl + : public DeviceImageToColumn +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{ + MPerBlock, 0 /* NPerBlock*/, KPerBlock}; + + // Use MakeADescriptor_M_K from grouped convolution forward + static auto + MakeInputDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + std::array a_g_n_c_wis_lengths{1}; + std::array b_g_k_c_xs_lengths{1}; + std::array c_g_n_k_wos_lengths{1}; + + auto copy = [](const auto& x, auto& y, index_t dst_offset) { + std::copy(x.begin(), x.end(), y.begin() + dst_offset); + }; + + constexpr index_t spatial_offset = 3; + + copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset); + copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset); + copy(output_spatial_lengths, c_g_n_k_wos_lengths, spatial_offset); + + // fill only significant values (C and N) + a_g_n_c_wis_lengths[I1] = N; + a_g_n_c_wis_lengths[I2] = C; + b_g_k_c_xs_lengths[I2] = C; + c_g_n_k_wos_lengths[I1] = N; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K( + a_g_n_c_wis_lengths, + input_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + return in_gemmm_gemmk_desc; + } + + static auto + MakeOutDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& output_m_k_strides) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const auto desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(NDoHoWo, CZYX), make_tuple(output_m_k_strides[I0], output_m_k_strides[I1])); + + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_mraw_kraw); + return desc_m_k; + } + + using InputGridDesc = + remove_cvref_t; + using OutputGridDesc = remove_cvref_t; + + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + OutputGridDesc{}))>; + + using GridwiseImageToColumnKernel = GridwiseImageToColumn; + + struct Argument : public BaseArgument + { + Argument(const void* p_in, // input image + void* p_out, // output image + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_g_n_c_wis_strides, + const std::array& output_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + : C_(C), + X_(filter_spatial_lengths[NDimSpatial - I1]), + p_in_{static_cast(p_in)}, + p_out_{static_cast(p_out)}, + input_g_n_c_wis_strides_{input_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + + in_grid_desc_m_k_ = MakeInputDescriptor_M_K(N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + input_g_n_c_wis_strides, + + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + out_grid_desc_m_k_ = MakeOutDescriptor_M_K( + N, C, filter_spatial_lengths, output_spatial_lengths, output_m_k_strides); + } + + void Print() const + { + std::cout << in_grid_desc_m_k_ << std::endl; + std::cout << out_grid_desc_m_k_ << std::endl; + } + + const ck::index_t C_; + const ck::index_t X_; + + const InputDataType* p_in_; + OutputDataType* p_out_; + + const std::array& input_g_n_c_wis_strides_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + + InputGridDesc in_grid_desc_m_k_; + OutputGridDesc out_grid_desc_m_k_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt( + arg.out_grid_desc_m_k_); + const index_t grid_size = block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_); + const auto kernel = kernel_image_to_column; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k_, + arg.p_in_, + arg.out_grid_desc_m_k_, + arg.p_out_, + block_2_tile_map); + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const Argument& arg) + { + using namespace tensor_layout::convolution; + if(!(std::is_same_v || std::is_same_v || + std::is_same_v)) + { + return false; + } + if(!(NDimSpatial >= 1 && NDimSpatial <= 3)) + { + return false; + } + + const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1]; + const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; + const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; + const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; + bool is_w_packed = arg.input_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; + bool is_c_packed = arg.input_g_n_c_wis_strides_[I2] == 1; + + // check vector acces with c not packed + if(!is_c_packed && ScalarPerVector != 1) + return false; + // check vector access of filter window row (only C if C is not packed) + if(!is_w_packed && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of filter window row (X * C) + if(arg.X_ * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of pads (w_pad_left/w_pad_right * C) + if(w_pad_left * arg.C_ % ScalarPerVector != 0 || + w_pad_right * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with stride and pad + if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with dilation + if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + + return GridwiseImageToColumnKernel::CheckValidity(arg.in_grid_desc_m_k_, + arg.out_grid_desc_m_k_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_in, // input image + void* p_out, // output image + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_g_n_c_wis_strides, + const std::array& output_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + return Argument{static_cast(p_in), + static_cast(p_out), + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + input_g_n_c_wis_strides, + output_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + void* p_out, // output image + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_g_n_c_wis_strides, + const std::array& output_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) override + { + return std::make_unique(static_cast(p_in), + static_cast(p_out), + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + input_g_n_c_wis_strides, + output_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceImageToColumn" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << KPerBlock << ", " + << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e98a85defe95aabed5c977250048cae117b701d6 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd +{ + using DInDataType_AutomicAddPreCast = + conditional_t || is_same_v, + DInDataType, + float>; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert; + + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t loop_step) + { + const auto m = desc_m.GetLength(I0); + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(index_t length, index_t loop_step) + { + const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); + return PadDescriptor_M_1d(desc_m, loop_step); + } + + using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1)); + + using GridwisePutElementSet = GridwisePutElement_1D; + + using GridwisePutElementAtomicAdd = GridwisePutElement_1D; + + using GridwiseCasting = GridwiseElementwise_1D, + Tuple, + Tuple, + Tuple, + UnaryConvert, + InOutVectorSize, + Sequence, + Sequence>; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + const IndexDataType* p_indices, + DInDataType* p_din, + index_t dout_length, + index_t din_length, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations) + : p_dout_{p_dout}, + p_indices_{p_indices}, + p_din_{p_din}, + dout_length_raw_{dout_length}, + din_length_raw_{din_length}, + blockSize_{256}, + windowOverlap_{false} + { + for(size_t i = 0; i < window_lengths.size(); ++i) + { + auto eff = (window_lengths.at(i) - 1) * window_dilations.at(i) + 1; + windowOverlap_ |= eff > window_strides.at(i); + } + } + + const DOutDataType* p_dout_; + const IndexDataType* p_indices_; + DInDataType* p_din_; + index_t dout_length_raw_; + index_t din_length_raw_; + index_t blockSize_; + bool windowOverlap_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize; + InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step); + InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step); + + if constexpr(is_same_v || is_same_v) + { + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + if(arg.windowOverlap_) + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + else + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + else + { + if(arg.windowOverlap_) + { + if(arg.p_workspace_ == nullptr) + throw std::runtime_error("wrong! WorkSpace pointer has not been set"); + + hip_check_error( + hipMemsetAsync(arg.p_workspace_, + 0, + arg.din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast), + stream_config.stream_id_)); + + const auto put_kernel = kernel_put_element_1d; + + const auto cast_kernel = + kernel_elementwise_1d, + Tuple, + Tuple, + Tuple, + UnaryConvert>; + + float elapsed_time = launch_and_time_kernel( + stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + static_cast(arg.p_workspace_), + PassThrough{}); + + elapsed_time += launch_and_time_kernel( + stream_config, + cast_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + ck::make_tuple(din_grid_desc), + ck::make_tuple(din_grid_desc), + static_cast(arg.p_workspace_), + arg.p_din_, + UnaryConvert{}); + + return elapsed_time; + } + else + { + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + const auto put_kernel = kernel_put_element_1d; + + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + bool needCast = pArg_->windowOverlap_ && + !(is_same_v || is_same_v); + + if(!needCast) + return 0; + else + return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast); + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + if(pArg->din_length_raw_ % InOutVectorSize != 0 || + pArg->dout_length_raw_ % InOutVectorSize != 0) + { + return false; + } + return true; + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations) override + { + // Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are + // physical size of the packed tensor + return std::make_unique(static_cast(p_dout), + static_cast(p_indices), + static_cast(p_din), + dout_length, + din_length, + window_lengths, + window_strides, + window_dilations); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp index b49e1096829695f7f7d5b0f651a20f7aebc03f6b..aec5a65ccf41eafcbe6ebdf790cc75b0adbf2966 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp index 17a96e9f6f6c8039485afcab82d5908d92a59a24..6d1d5c8e239de8f57df1baa82d4981e8f30e54b6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp index 6a8037a3249c4e5b8ebc6a72696623f8d4022f06..ea0d805043e07fe3d818fbb7bf7ef42e952496a9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp index 0026a8759363b88cbec59cf010d9cf5c94cad753..8b2b3c41bfdf3f4793263dd65fc5f6dc941eb13b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp index 7b96373c0ff1fc86a79c4248db35dd5812f2c5b1..17dab08332166e23a93d5c8f25ecfec7556519ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp index 6933db68dccd83fe4e036a0609b7eaa7517a3902..c94c568c49aca819997cd8a1d32c536d7527a551 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp @@ -1,18 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include -#include - -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" -#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" namespace ck { namespace tensor_operation { @@ -30,255 +21,32 @@ template -struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C - : public DevicePoolFwd<4, 2, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex> +struct DevicePool2dFwd_NHWC_NHWC : public DevicePool3dFwd_NDHWC_NDHWC { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - - static constexpr index_t InOutRank = 4; - static constexpr index_t WindowRank = 2; - - using ReduceOperation = typename reduce_binary_operator::opType; - - using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; - - using AccElementwiseOperation = - typename reduce_unary_operator::AccElementwiseOperation; - - static constexpr index_t InSrcOutDstVectorDim = - 0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is - // not reduced. - - static constexpr ck::index_t ReduceM_BlockTileSize = - ReduceMThreadClusterSize * ReduceMThreadSliceSize; - static constexpr ck::index_t ReduceK_BlockTileSize = - ReduceKThreadClusterSize * ReduceKThreadSliceSize; - - static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector window_spatial_lengths, - std::vector output_spatial_lengths, - std::vector window_strides, - std::vector input_left_pads, - std::vector input_right_pads) - { - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = output_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[1]; - - const index_t Y = window_spatial_lengths[0]; - const index_t X = window_spatial_lengths[1]; - - const index_t ConvStrideH = window_strides[0]; - const index_t ConvStrideW = window_strides[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const index_t ReduceMRaw = N * Ho * Wo * C; - const index_t ReduceMPad = - math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw; - - const index_t ReduceKRaw = Y * X; - const index_t ReduceKPad = - math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw; - - // A[ReduceM, ReduceK] - const auto in_grid_desc_n_hi_wi_c = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor( - in_grid_desc_n_hi_wi_c, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor( - in_grid_desc_n_hip_wip_c, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_grid_desc_reducemraw_reducekraw = - transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)), - make_merge_transform(make_tuple(Y, X))), - make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( - in_grid_desc_reducemraw_reducekraw, - make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad), - make_right_pad_transform(ReduceKRaw, ReduceKPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // B[ReduceM] - const auto out_grid_desc_reducemraw = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C)); - - const auto out_grid_desc_reducem = transform_tensor_descriptor( - out_grid_desc_reducemraw, - make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - - return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); - } - - using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {})); - using AGridDesc_M_K = remove_cvref_t; - using BGridDesc_M = remove_cvref_t; - - // TODO - struct Argument : public BaseArgument - { - Argument(const InDataType* p_in_dev, - OutDataType* p_out_dev, - IndexDataType* p_out_indices_dev, - ck::index_t N, - ck::index_t C, - std::vector& input_spatial_lengths, - std::vector& window_spatial_lengths, - std::vector& output_spatial_lengths, - std::vector& window_strides, - std::vector& input_left_pads, - std::vector& input_right_pads) - : p_in_dev_{p_in_dev}, - p_out_dev_{p_out_dev}, - p_out_indices_dev_{p_out_indices_dev}, - a_grid_desc_m_k_{}, - b_grid_desc_m_{} - { - const auto descs = MakeABGridDescriptor_A_M_K_B_M(N, - C, - input_spatial_lengths, - window_spatial_lengths, - output_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); - - a_grid_desc_m_k_ = descs[I0]; - b_grid_desc_m_ = descs[I1]; - - invariant_lowest_length_ = C; - reduce_lowest_length_ = window_spatial_lengths[1]; - - int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; - - std::tie(in_element_op_, acc_element_op_) = - reduce_unary_operator::GetElementwiseOperator(reduceLength); - } - - const InDataType* p_in_dev_; - OutDataType* p_out_dev_; - IndexDataType* p_out_indices_dev_; - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_M b_grid_desc_m_; - InElementwiseOperation in_element_op_; - AccElementwiseOperation acc_element_op_; - - // for checking vector load/store - ck::index_t invariant_lowest_length_; - ck::index_t reduce_lowest_length_; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - using gridwise_reduce = - GridwiseReduction_mk_to_m_threadwise; - const auto kernel = - kernel_reduce_threadwise; - - ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0); - - const index_t grid_size = (ReduceM / ReduceM_BlockTileSize); - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.a_grid_desc_m_k_, - arg.b_grid_desc_m_, - arg.in_element_op_, - arg.acc_element_op_, - float(1), - arg.p_in_dev_, - nullptr, - float(0), - arg.p_out_dev_, - arg.p_out_indices_dev_); - } - - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* pArg = dynamic_cast(p_arg); - - if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0) - { - return (false); - } - - return (true); - } - std::unique_ptr MakeArgumentPointer(const void* p_in_dev, void* p_out_dev, @@ -286,62 +54,57 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C std::vector input_lengths, std::vector window_lengths, std::vector output_lengths, - std::vector, // Suppose tensor layout = NHWC - std::vector, // Suppose tensor layout = NHWC - std::vector, // Suppose tensor layout = NHWC + std::vector input_stride, + std::vector output_stride, + std::vector indices_stride, std::vector window_strides, + std::vector window_dilations, std::vector input_left_pads, std::vector input_right_pads, std::vector pooling_dims) override { + static constexpr index_t InOutRank = 4; + static constexpr index_t WindowRank = 2; + if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || input_lengths.size() != InOutRank || window_strides.size() != WindowRank || - input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) + window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank || + input_right_pads.size() != WindowRank) throw std::runtime_error("dimension is incorrect"); if(pooling_dims != std::vector{2, 3}) throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far"); - index_t N = input_lengths[0]; - index_t C = input_lengths[1]; - index_t Hi = input_lengths[2]; - index_t Wi = input_lengths[3]; - index_t Ho = output_lengths[2]; - index_t Wo = output_lengths[3]; - - std::vector input_spatial_lengths = {Hi, Wi}; - std::vector output_spatial_lengths = {Ho, Wo}; - - return std::make_unique(static_cast(p_in_dev), - static_cast(p_out_dev), - static_cast(p_out_indices_dev), - N, - C, - input_spatial_lengths, - window_lengths, - output_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); - } - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ","; - str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ","; - str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ","; - str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; - // clang-format on - - return str.str(); + // NCHW to NCDHW + input_lengths.insert(input_lengths.begin() + 2, 1); + output_lengths.insert(output_lengths.begin() + 2, 1); + input_stride.insert(input_stride.begin() + 2, 0); + output_stride.insert(output_stride.begin() + 2, 0); + indices_stride.insert(indices_stride.begin() + 2, 0); + + // YX to ZYX + window_lengths.insert(window_lengths.begin(), 1); + window_strides.insert(window_strides.begin(), 0); + window_dilations.insert(window_dilations.begin(), 0); + input_left_pads.insert(input_left_pads.begin(), 0); + input_right_pads.insert(input_right_pads.begin(), 0); + + pooling_dims = {2, 3, 4}; + + return DevicePool3D::MakeArgumentPointer(p_in_dev, + p_out_dev, + p_out_indices_dev, + input_lengths, + window_lengths, + output_lengths, + input_stride, + output_stride, + indices_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp index d330fda8ce4ea8dd290f3a393dd4004ae55d9cc2..384805a0a3de7f8cb083daaf1588e8630d1ad6a9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,8 +8,10 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -30,8 +32,15 @@ template -struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C - : public DevicePoolFwd<5, 3, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex> +struct DevicePool3dFwd_NDHWC_NDHWC : public DevicePoolFwd<5, + 3, + InDataType, + OutDataType, + IndexDataType, + tensor_layout::convolution::NDHWC, + tensor_layout::convolution::NDHWC, + ReduceOpId, + OutputIndex> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -51,45 +60,48 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C using AccElementwiseOperation = typename reduce_unary_operator::AccElementwiseOperation; - // for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not - // reduced. - static constexpr index_t InSrcOutDstVectorDim = 0; - static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector window_spatial_lengths, - std::vector output_spatial_lengths, - std::vector window_strides, - std::vector input_left_pads, - std::vector input_right_pads) + static auto MakeABGridDescriptor_A_M_K_B_M(std::vector input_ncdhw_lengths, + std::vector output_ncdhw_lengths, + std::vector input_ncdhw_stride, + std::vector output_ncdhw_stride, + std::vector window_spatial_zyx_lengths, + std::vector window_zyx_strides, + std::vector window_zyx_dilations, + std::vector input_left_dhw_pads, + std::vector input_right_dhw_pads) { - const index_t Di = input_spatial_lengths[0]; - const index_t Hi = input_spatial_lengths[1]; - const index_t Wi = input_spatial_lengths[2]; + const index_t N = input_ncdhw_lengths[0]; + const index_t C = input_ncdhw_lengths[1]; + const index_t Di = input_ncdhw_lengths[2]; + const index_t Hi = input_ncdhw_lengths[3]; + const index_t Wi = input_ncdhw_lengths[4]; + + const index_t Do = output_ncdhw_lengths[2]; + const index_t Ho = output_ncdhw_lengths[3]; + const index_t Wo = output_ncdhw_lengths[4]; - const index_t Do = output_spatial_lengths[0]; - const index_t Ho = output_spatial_lengths[1]; - const index_t Wo = output_spatial_lengths[2]; + const index_t Z = window_spatial_zyx_lengths[0]; + const index_t Y = window_spatial_zyx_lengths[1]; + const index_t X = window_spatial_zyx_lengths[2]; - const index_t Z = window_spatial_lengths[0]; - const index_t Y = window_spatial_lengths[1]; - const index_t X = window_spatial_lengths[2]; + const index_t WindowStrideD = window_zyx_strides[0]; + const index_t WindowStrideH = window_zyx_strides[1]; + const index_t WindowStrideW = window_zyx_strides[2]; - const index_t ConvStrideD = window_strides[0]; - const index_t ConvStrideH = window_strides[1]; - const index_t ConvStrideW = window_strides[2]; + const index_t WindowDilationD = window_zyx_dilations[0]; + const index_t WindowDilationH = window_zyx_dilations[1]; + const index_t WindowDilationW = window_zyx_dilations[2]; - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; + const index_t InLeftPadD = input_left_dhw_pads[0]; + const index_t InLeftPadH = input_left_dhw_pads[1]; + const index_t InLeftPadW = input_left_dhw_pads[2]; - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; + const index_t InRightPadD = input_right_dhw_pads[0]; + const index_t InRightPadH = input_right_dhw_pads[1]; + const index_t InRightPadW = input_right_dhw_pads[2]; const index_t MRaw = N * Do * Ho * Wo * C; const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; @@ -98,8 +110,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; // A[ReduceM, ReduceK] - const auto in_grid_desc_n_di_hi_wi_c = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + const index_t Ni_stride = input_ncdhw_stride[0]; + const index_t Ci_stride = input_ncdhw_stride[1]; + const index_t Di_stride = input_ncdhw_stride[2]; + const index_t Hi_stride = input_ncdhw_stride[3]; + const index_t Wi_stride = input_ncdhw_stride[4]; + + const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride)); const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( in_grid_desc_n_di_hi_wi_c, @@ -113,11 +132,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( in_grid_desc_n_dip_hip_wip_c, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(I1, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)), - make_pass_through_transform(C)), + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)), + make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, @@ -139,8 +159,21 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C make_tuple(Sequence<0>{}, Sequence<1>{})); // B[ReduceM] - const auto out_grid_desc_reducemraw = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo * C)); + const index_t No_stride = output_ncdhw_stride[0]; + const index_t Co_stride = output_ncdhw_stride[1]; + const index_t Do_stride = output_ncdhw_stride[2]; + const index_t Ho_stride = output_ncdhw_stride[3]; + const index_t Wo_stride = output_ncdhw_stride[4]; + + const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride)); + + const auto out_grid_desc_reducemraw = transform_tensor_descriptor( + out_grid_desc_n_do_ho_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); const auto out_grid_desc_reducem = transform_tensor_descriptor(out_grid_desc_reducemraw, @@ -151,7 +184,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); } - using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {})); + using ABGridDescs = + decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {})); + using AGridDesc_M_K = remove_cvref_t; using BGridDesc_M = remove_cvref_t; @@ -160,36 +195,41 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C Argument(const InDataType* p_in_dev, OutDataType* p_out_dev, IndexDataType* p_out_indices_dev, - ck::index_t N, - ck::index_t C, - std::vector& input_spatial_lengths, - std::vector& window_spatial_lengths, - std::vector& output_spatial_lengths, - std::vector& window_strides, - std::vector& input_left_pads, - std::vector& input_right_pads) + std::vector& input_ncdhw_lengths, + std::vector& output_ncdhw_lengths, + std::vector& input_ncdhw_stride, + std::vector& output_ncdhw_stride, + std::vector&, // indices_ncdhw_stride + std::vector& window_spatial_zyx_lengths, + std::vector& window_zyx_strides, + std::vector& window_zyx_dilations, + std::vector& input_left_dhw_pads, + std::vector& input_right_dhw_pads) : p_in_dev_{p_in_dev}, p_out_dev_{p_out_dev}, p_out_indices_dev_{p_out_indices_dev}, a_grid_desc_m_k_{}, - b_grid_desc_m_{} + b_grid_desc_m_{}, + input_ncdhw_lengths_{input_ncdhw_lengths}, + output_ncdhw_lengths_{output_ncdhw_lengths}, + input_ncdhw_stride_{input_ncdhw_stride}, + output_ncdhw_stride_{output_ncdhw_stride} { - const auto descs = MakeABGridDescriptor_A_M_K_B_M(N, - C, - input_spatial_lengths, - window_spatial_lengths, - output_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); + const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths, + output_ncdhw_lengths, + input_ncdhw_stride, + output_ncdhw_stride, + window_spatial_zyx_lengths, + window_zyx_strides, + window_zyx_dilations, + input_left_dhw_pads, + input_right_dhw_pads); a_grid_desc_m_k_ = descs[I0]; b_grid_desc_m_ = descs[I1]; - invariant_lowest_length_ = C; - - int32_t reduceLength = - window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2]; + int32_t reduceLength = window_spatial_zyx_lengths[0] * window_spatial_zyx_lengths[1] * + window_spatial_zyx_lengths[2]; std::tie(in_element_op_, acc_element_op_) = reduce_unary_operator::GetElementwiseOperator(reduceLength); @@ -200,17 +240,25 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C IndexDataType* p_out_indices_dev_; AGridDesc_M_K a_grid_desc_m_k_; BGridDesc_M b_grid_desc_m_; + InElementwiseOperation in_element_op_; AccElementwiseOperation acc_element_op_; // for checking vector load/store - ck::index_t invariant_lowest_length_; + std::vector input_ncdhw_lengths_; + std::vector output_ncdhw_lengths_; + std::vector input_ncdhw_stride_; + std::vector output_ncdhw_stride_; }; struct Invoker : public BaseInvoker { float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + // for NDHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. + static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K + using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise(p_arg); - if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0) - { + // C should be fastest dimension + if(pArg->input_ncdhw_stride_[1] != 1) return false; + + for(int i = 0; i < InOutRank; ++i) + { + if(pArg->input_ncdhw_stride_[i] == 1 && + pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; + + if(pArg->output_ncdhw_stride_[i] == 1 && + pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; } return true; } - std::unique_ptr + virtual std::unique_ptr MakeArgumentPointer(const void* p_in_dev, void* p_out_dev, void* p_out_indices_dev, - std::vector input_lengths, - std::vector window_lengths, - std::vector output_lengths, - std::vector, // Suppose tensor layout = NDHWC - std::vector, // Suppose tensor layout = NDHWC - std::vector, // Suppose tensor layout = NDHWC - std::vector window_strides, - std::vector input_left_pads, - std::vector input_right_pads, + std::vector input_ncdhw_lengths, + std::vector window_zyx_lengths, + std::vector output_ncdhw_lengths, + std::vector input_ncdhw_stride, + std::vector output_ncdhw_stride, + std::vector indices_ncdhw_stride, + std::vector window_zyx_strides, + std::vector window_zyx_dilations, + std::vector input_left_dhw_pads, + std::vector input_right_dhw_pads, std::vector pooling_dims) override { - if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || - input_lengths.size() != InOutRank || window_strides.size() != WindowRank || - input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) + if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank || + input_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank || + window_zyx_dilations.size() != WindowRank || input_left_dhw_pads.size() != WindowRank || + input_right_dhw_pads.size() != WindowRank) throw std::runtime_error("dimension is incorrect"); if(pooling_dims != std::vector{2, 3, 4}) throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far"); - index_t N = input_lengths[0]; - index_t C = input_lengths[1]; - index_t Di = input_lengths[2]; - index_t Hi = input_lengths[3]; - index_t Wi = input_lengths[4]; - index_t Do = output_lengths[2]; - index_t Ho = output_lengths[3]; - index_t Wo = output_lengths[4]; - - std::vector input_spatial_lengths = {Di, Hi, Wi}; - std::vector output_spatial_lengths = {Do, Ho, Wo}; + if(output_ncdhw_stride != indices_ncdhw_stride) + throw std::runtime_error( + "output_ncdhw_stride need to be equal to indices_ncdhw_stride for now"); return std::make_unique(static_cast(p_in_dev), static_cast(p_out_dev), static_cast(p_out_indices_dev), - N, - C, - input_spatial_lengths, - window_lengths, - output_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); + input_ncdhw_lengths, + output_ncdhw_lengths, + input_ncdhw_stride, + output_ncdhw_stride, + indices_ncdhw_stride, + window_zyx_lengths, + window_zyx_strides, + window_zyx_dilations, + input_left_dhw_pads, + input_right_dhw_pads); } std::unique_ptr MakeInvokerPointer() override @@ -342,7 +396,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C auto str = std::stringstream(); // clang-format off - str << "DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<" << BlockSize << ","; + str << "DevicePool3dFwd_NDHWC_NDHWC<" << BlockSize << ","; str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7334da0e35af5ce92665decc4897ebe48393b9c3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_put_element.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DevicePutElementImpl + : public DevicePutElement +{ + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + constexpr auto I0 = Number<0>{}; + + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * InVectorSize; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize) + { + const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + + using InGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1)); + + using GridwisePutElement = GridwisePutElement_1D; + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_input, + const IndexDataType* p_indices, + OutDataType* p_output, + index_t input_length, + ElementwiseOperation elementwise_op) + : p_input_{p_input}, + p_indices_{p_indices}, + p_output_{p_output}, + input_length_raw_{input_length}, + elementwise_op_{elementwise_op}, + blockSize_{256} + { + } + + const InDataType* p_input_; + const IndexDataType* p_indices_; + OutDataType* p_output_; + index_t input_length_raw_; + ElementwiseOperation elementwise_op_; + index_t blockSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + InGrid1dDesc in_grid_desc = + MakeDescriptor_M(arg.input_length_raw_, gridSize, arg.blockSize_); + + const auto kernel = kernel_put_element_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + in_grid_desc, + arg.p_input_, + arg.p_indices_, + arg.p_output_, + arg.elementwise_op_); + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg->input_length_raw_ % InVectorSize != 0) + { + return false; + } + return true; + } + + std::unique_ptr MakeArgumentPointer(const void* p_input, + const void* p_indices, + void* p_output, + index_t input_length, + index_t, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(static_cast(p_input), + static_cast(p_indices), + static_cast(p_output), + input_length, + elementwise_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp index 5dc051be3cb7f23fef076921e826a7e55e017d76..2481c5c76971db6bb2ddc390510089af5f656ae6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp index c7868537fe8220c640ba24d346eea5dc219490fa..bf3deeb57acbdff95995f5da096ab574140bcb42 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp index aa255da645e53e0f0a6bf9578d0c64b538b30c8e..6c5895b010659cafb98e4a78bc6587d0222e02bd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp index ed96b7340cfc6c78468d8fb9bfb2df21243918ef..8eff9d241509bdc4b1cc3993a39f0e3c7fa6828d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -38,16 +38,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax + Rank, + NumReduceDim> { - static constexpr index_t kRank = Rank; - static constexpr index_t kNumReduceDim = NumReduceDim; - static constexpr index_t kNumInvariantDim = Rank - NumReduceDim; - - virtual index_t GetRank() const override { return kRank; } - - virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } - static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumSrcDim = Rank; @@ -287,13 +280,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0) + if(NumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp index 2f29224a754057430cd68bb551807226d40a6cb5..7a62ec04650a94d1b9f458e425ddf3b14881091d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp similarity index 98% rename from include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index 70990e795cc7ea594a94aee24737fc3dfc6ba664..cfab7d29c9002f4fc61b154e127c4f09ab9005c9 100644 --- a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,7 +57,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = @@ -617,10 +617,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using AGridDesc_AKB_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BKB_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AKB_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BKB_BK0_N_BK1 = + remove_cvref_t; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; @@ -886,11 +888,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap, has_main_loop>; - hipGetErrorString(hipMemset( + hipGetErrorString(hipMemsetAsync( arg.p_e_grid_, 0, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * - sizeof(EDataType))); + sizeof(EDataType), + stream_config.stream_id_)); return launch_and_time_kernel(stream_config, kernel, @@ -939,8 +942,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index ea0f5897a757c408047b20e0cd4285e13f67c4f5..d6d6f74abdb5a63ef8a841d033c9f47ec59bd07e 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index 70e61bc7728f9638e5bef874ab0d5caad2db5028..c66d2fc516babe3aa5a16de053d9ada89eabf8ef 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp index d35318357a9bcc6fb087c92e27dcf435ff60796b..5351d4ef2439307ad5623f4172217937812ab29a 100644 --- a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp +++ b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index b44427411f9a28511fd16f6b47ee37fa84a0e2d3..b2d141fd61aa4550ef8b572e548688120634b8b9 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp b/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp index 0ec0df2c9bbbb2aac987d681a36d5b00a8349c97..713fc93ebb4eec9ce0301affd667656bf0b6c999 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/welford_helper.hpp b/include/ck/tensor_operation/gpu/device/welford_helper.hpp index 6c909b767d4602914660d045b9aaf5a2307ea57b..d7772d8764ff3cc66572775f750997fae944800e 100644 --- a/include/ck/tensor_operation/gpu/device/welford_helper.hpp +++ b/include/ck/tensor_operation/gpu/device/welford_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 0d78e135893df64402ff56552190874c7b909d8a..9fe0931cbae44250295720f090bda56588d3fb15 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -36,6 +36,13 @@ struct Add y = x0 + type_convert(x1); }; + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 + x1); + }; + template <> __host__ __device__ constexpr void operator()(half_t& y, const float& x0, const half_t& x1) const diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index ceb2b665b91529f0a97cfb1209b2de40744ee466..9f5ed6adea1915a5a1ea2f67ba65b7641db22855 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -195,6 +195,51 @@ struct AddMultiply } }; +// C = A * B +// E = C x D0 + D1 +struct MultiplyAdd +{ + template + __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ void operator()(half_t& e, + const half_t& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = (c * d0) + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = type_convert(c) * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(float& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const float y = c * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const float& d0, + const float& d1) const + { + const float y = c * d0 + d1; + e = y; + } +}; + // E = FastGelu(C + D0 + D1) struct AddAddFastGelu { diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index ef250b8bfd1cbefabcdc3b95a0d94e0c6168c762..905908a1c3b0d700b43ec8dc209b984a71e7baa3 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/data_type.hpp" #include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" +#include "ck/utility/type_convert.hpp" namespace ck { namespace tensor_operation { @@ -38,6 +39,12 @@ struct PassThrough y = x; } + template <> + __host__ __device__ void operator()(half_t& y, const float& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { @@ -81,6 +88,36 @@ struct PassThrough y = x; } #endif + + template <> + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const half_t& x) const + { + y = type_convert(x); + } }; struct UnaryConvert @@ -109,6 +146,23 @@ struct ConvertBF16RTN } }; +struct ConvertF8SR +{ + // convert to fp8 using stochastic rounding (SR) + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value, "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = f8_convert_sr(x); + } +}; + struct Scale { __host__ __device__ Scale(float scale) : scale_(scale) {} diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..47573107cf4aa9296e3724e1b470223e623c4b35 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp @@ -0,0 +1,704 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/workgroup_synchronization.hpp" + +namespace ck { + +template +__global__ void kernel_multiblock_batchnorm_forward( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K y_grid_desc_m_k, + const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, + const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const __restrict__ p_welford_mean, + MeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count, + int32_t* const __restrict__ p_control, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) +{ + GridwiseMultiblockBatchNormForward_::Run(x_grid_desc_m_k, + y_grid_desc_m_k, + mean_var_count_grid_desc_m_g, + mean_var_count_grid_desc_m_k, + scale_grid_desc_m, + bias_grid_desc_m, + mean_var_grid_desc_m, + get_reduce_count_per_thread, + num_k_block_tile_iteration, + epsilon, + p_x, + p_welford_mean, + p_welford_variance, + p_welford_count, + p_control, + p_scale, + p_bias, + y_elementwise_op, + p_y, + updateMovingAverage, + averageFactor, + resultRunningMean, + resultRunningVariance, + saveMeanInvVariance, + resultSaveMean, + resultSaveInvVariance); +}; + +template +struct GridwiseMultiblockBatchNormForward +{ + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); + + using ThreadwiseWelford1 = + ThreadwiseWelford; + + using ThreadwiseWelford2 = + ThreadwiseWelfordMerge; + + using BlockwiseWelford1 = BlockwiseWelford; + + using BlockwiseWelford2 = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& y_grid_desc_m_k, + const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g, + const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& bias_grid_desc_m, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const __restrict__ p_welford_mean, + MeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count, + int32_t* const __restrict__ p_control, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) + { + using ck::math::sqrt; + + const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + if(block_local_id == 0) + gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + StaticBuffer + x_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + StaticBuffer count_thread_buf; + + StaticBuffer + tmp_mean_thread_buf; + StaticBuffer + tmp_var_thread_buf; + StaticBuffer tmp_count_thread_buf; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto xy_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + // Step 1: each workgroup does local welford reduction + + auto threadwise_welford_1 = ThreadwiseWelford1(); + threadwise_welford_1.max_count_ = + get_reduce_count_per_thread(block_local_id, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k); + threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + count_thread_buf(I) = threadwise_welford_1.cur_count_; + BlockwiseWelford1::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I)); + }); + + // Step 2: each workgroup writes its local welford result to workspace memory + + auto mean_global_val_buf = + make_dynamic_buffer( + p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto var_global_val_buf = + make_dynamic_buffer( + p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto count_global_val_buf = + make_dynamic_buffer( + p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto threadwise_mean_var_store_m_g = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + auto threadwise_count_store_m_g = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + if(thread_k_cluster_id == 0) + { + threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + mean_thread_buf, + mean_var_count_grid_desc_m_g, + mean_global_val_buf); + + threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + var_thread_buf, + mean_var_count_grid_desc_m_g, + var_global_val_buf); + + threadwise_count_store_m_g.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + count_thread_buf, + mean_var_count_grid_desc_m_g, + count_global_val_buf); + }; + + gms_barrier(&p_control[blkgroup_id * 2]); + + if(block_local_id == 0) + gms_reset(&p_control[blkgroup_id * 2]); + + // Step 3: each workgroup reads welford results from workspace memory and does final welford + // reduction + + auto threadwise_mean_var_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 0, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_count_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 0, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + count_thread_buf(I) = 0; + }); + + constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize); + + int32_t reducedSize = 0; + while(reducedSize < blkgroup_size) + { + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + tmp_mean_thread_buf); + + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + tmp_var_thread_buf); + + threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k, + count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + tmp_count_thread_buf); + + ThreadwiseWelford2::Run(tmp_mean_thread_buf, + tmp_var_thread_buf, + tmp_count_thread_buf, + mean_thread_buf, + var_thread_buf, + count_thread_buf); + + reducedSize += KThreadClusterSize; + + threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_read_fwd_step_m_k); + threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_read_fwd_step_m_k); + }; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford2::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I)); + }); + + // Step 4: do normalization using the mean/variance + + StaticBuffer scale_thread_buf; + + StaticBuffer bias_thread_buf; + + StaticBuffer + y_thread_buf; + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize), + y_elementwise_op); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BiasSrcVectorSize, + 1, + true>( + bias_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + const auto scale_global_val_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto bias_global_val_buf = make_dynamic_buffer( + p_bias, bias_grid_desc_m.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y, y_grid_desc_m_k.GetElementSpaceSize()); + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_bias_load.Run(bias_grid_desc_m, + bias_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + bias_thread_buf); + + threadwise_x_load.SetSrcSliceOrigin( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + scale_thread_buf[Number{}] / sqrt(var_thread_buf[iM] + epsilon); + + AccDataType fused_mean_bias = + bias_thread_buf[Number{}] - mean_thread_buf[iM] * multiplier; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + // normalize + y_thread_buf(Number{}) = + x_thread_buf[Number{}] * multiplier + fused_mean_bias; + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k); + } + + // Step 5: update the moving average of mean and variance (optional) + + if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0) + { + StaticBuffer + running_mean_thread_buf; + StaticBuffer + running_var_thread_buf; + + auto running_mean_global_buf = make_dynamic_buffer( + resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto running_var_global_buf = make_dynamic_buffer( + resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcDstVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf); + + AccDataType oneMinusAverageFactor = type_convert(1.0) - averageFactor; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor + + mean_thread_buf[I] * averageFactor; + running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor + + var_thread_buf[I] * averageFactor; + }); + + auto threadwise_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf, + mean_var_grid_desc_m, + running_mean_global_buf); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf, + mean_var_grid_desc_m, + running_var_global_buf); + }; + + // Step 6: save mean and inv-variance (optional) + + if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0) + { + auto result_mean_global_buf = make_dynamic_buffer( + resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto result_inv_var_global_buf = make_dynamic_buffer( + resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + var_thread_buf(I) = + type_convert(1.0f) / sqrt(epsilon + var_thread_buf[I]); + }); + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + mean_var_grid_desc_m, + result_mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + var_thread_buf, + mean_var_grid_desc_m, + result_inv_var_global_buf); + }; + } +}; // namespace ck + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp index a72a4ee068f2bbcb1a915df0368c9b05caa298f9..4e182ec29ddd1be0ff83b72a1da956ab0353a37a 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -118,8 +118,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - using ThreadReduceSrcDesc_M_1 = decltype( - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp index 08cb0dd191303813a9283b19e38e2952fa0ccf27..a82a173500d21305cf5b643553d4ddf544a62d0d 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf PassThroughOp, ThreadBufferLengths_M_1, Sequence<0, 1>, - 1, + 0, 1, InMemoryDataOperationEnum::Set, 1, @@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf PassThroughOp, ThreadBufferLengths_M_1, Sequence<0, 1>, - 1, + 0, 1, InMemoryDataOperationEnum::Set, 1, diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp rename to include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp index 548d7fd40ac70fe7fde5e5973a177600b0941c5c..672be91a79ed87322ce228cf558e33b505fe6f5c 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final( const MeanVarGridDesc_M mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, - index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType* const __restrict__ p_in_welford_mean, const MeanVarDataType* const __restrict__ p_in_welford_variance, @@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final( mean_var_grid_desc_m, blkgroup_size, num_xy_k_block_tile_iteration, - num_mean_var_count_k_block_tile_iteration, epsilon, p_in_welford_mean, p_in_welford_variance, @@ -123,8 +121,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - using ThreadReduceSrcDesc_M_1 = decltype( - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); @@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal const MeanVarGridDesc_M& mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, - index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType* const __restrict__ p_in_welford_mean, const MeanVarDataType* const __restrict__ p_in_welford_variance, @@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal decltype(thread_buffer_desc_m_1), ThreadBufferLengths_M_1, Sequence<0, 1>, - 1, + 0, 1, 1, true>( @@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal decltype(thread_buffer_desc_m_1), ThreadBufferLengths_M_1, Sequence<0, 1>, - 1, + 0, 1, 1, true>( @@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal const auto welford_count_global_val_buf = make_dynamic_buffer( p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); - constexpr auto mean_var_count_thread_copy_step_m_k = - make_multi_index(0, KThreadClusterSize * 1); - // Step 1: do final welford reduction to get mean and variance static_for<0, MThreadSliceSize, 1>{}([&](auto I) { @@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal welford_count_thread_buf(I) = 0; }); - for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration; - ++reducedTiles) + constexpr auto mean_var_count_thread_copy_step_m_k = + make_multi_index(0, KThreadClusterSize); + + int32_t reducedSize = 0; + while(reducedSize < blkgroup_size) { threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, welford_mean_global_val_buf, @@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal welford_var_thread_buf, welford_count_thread_buf); + reducedSize += KThreadClusterSize; + threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k); threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp index 42b7e172b239e69ae0ca840dd80d8bb203a5480f..2d5dc90bfb3744709bc51219c634d875994dda65 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -115,8 +115,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); - using ThreadReduceSrcDesc_M_1 = decltype( - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index ad91c3c6872d91c0c5fc73021f2464df59b1e33f..7bb47e9d3c4b3a38c14ced3059d9efe3c3269763 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,8 @@ #include "ck/utility/number.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" +#include +#include namespace ck { @@ -134,6 +136,14 @@ struct BlockToCTileMap_M00_N0_M01Adapt { } + template + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) + : BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) + { + } + __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) { const auto M0 = math::integer_divide_ceil(M, MPerBlock); @@ -142,6 +152,18 @@ struct BlockToCTileMap_M00_N0_M01Adapt return M0 * N0; } + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + template __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const { @@ -222,30 +244,12 @@ struct BlockToCTileMap_M00_N0_M01Adapt index_t M01_; }; +// keep the redundant type argument for backward compatibility template struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt { - using Parent = BlockToCTileMap_M00_N0_M01Adapt; - - using Parent::I0; - using Parent::I1; - - using Parent::Parent; - using Parent::operator=; - - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01 = 8) - : Parent(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) - { - } - - __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - return Parent::CalculateGridSize(c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); - } - - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + using BlockToCTileMap_M00_N0_M01Adapt:: + BlockToCTileMap_M00_N0_M01Adapt; }; // 2D slices of column-vectors in 3D space @@ -583,7 +587,8 @@ struct OffsettedBlockToCTileMap { using underlying_type = UnderlyingBlockToCTileMap; - OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start) + __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t block_start) { block_to_ctile_map_ = block_to_ctile_map; block_start_ = block_start; @@ -667,4 +672,406 @@ struct BlockToCTileMap_3DGrid_KSplit } }; +enum StreamKReductionStrategy +{ + Atomic = 0, // sk block use atomic to do reduction + Reduction, // let some workgroup responsible for doing the reduction operation +}; + +template +struct BlockToCTileMap_GemmStreamK +{ + static constexpr uint32_t min_k_iters_per_sk_block = 2; + static constexpr uint32_t MPerBlock = MPerBlock_; + static constexpr uint32_t NPerBlock = NPerBlock_; + static constexpr uint32_t KPerBlock = KPerBlock_; + static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_; + static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_; + + //-------------------------------------- + // pass to device + uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + MDiv2 n_tiles; + MDiv k_iters_per_tile; + MDiv eqav_tiles_big; // for reduction + MDiv eqav_tiles_little; // for reduction + + // MDiv tile_swizzle_sub_m_rem; + //-------------------------------------- + + // prefer construct on host + BlockToCTileMap_GemmStreamK(uint32_t m, + uint32_t n, + uint32_t k, + uint32_t num_cu, + uint32_t occupancy, + uint32_t sk_blocks = 0xffffffff) + { + uint32_t num_tiles = + math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); + k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock)); + + // one cu can hold one wg at one time, from the whole chip's point of view + // if number of wg is same as num_cu, we call it 1 dispatch + // if number of wg is 2x num_cu, we call it 2 dispatches. + // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial + // dispatch) + // + uint32_t full_dispatches = num_tiles / num_cu; + uint32_t full_dispatch_tiles = full_dispatches * num_cu; + uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles; + + uint32_t sk_occupancy = occupancy; + uint32_t dp_tiles = full_dispatch_tiles; + uint32_t sk_tiles = partial_dispatche_tiles; + + if(full_dispatches < occupancy) + { + // in this case, we allocate all blocks as sk blocks + // sk_occupancy = occupancy - full_dispatches; + sk_occupancy = 1; // TODO: single occ seems better + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatche_tiles; + } + else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1)) + { + // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ... + // occupancy = 3, full_dispatches = 5, 8, 11 ... + // occupancy = 4, full_dispatches = 7, 11 ... + sk_occupancy = 1; // left 1 slot for sk occupancy + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatche_tiles; + } + else + { + // others, we reduce 1 dispatch from dp, together with partial dispatch, + // to construct sk dispatch + sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy); + dp_tiles = full_dispatch_tiles - num_cu; + sk_tiles = partial_dispatche_tiles + num_cu; + } + + // uint32_t dp_iters_per_block = k_iters_per_tile.get(); + uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles; + uint32_t dp_num_blocks = 0; + + { + uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1); + uint32_t max_sk_tiles = + (sk_tiles >= num_cu) ? num_cu * sk_occupancy + : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block); + + // if use dp for sk-block, how many iters do we need + uint32_t dp_for_sk_iters = k_iters_per_tile.get(); + + uint32_t best_sk_score = + std::numeric_limits::max(); // we need to find the smallest sk iters + for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; + tentative_sk_blocks++) + { + uint32_t tentative_sk_iters_per_block = + (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks; + uint32_t tentative_sk_iters = tentative_sk_iters_per_block; + uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles; + + // TODO: carefully adjust this parameter + // the more sk_blocks_per_tile, the worse the overhead + uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile; + if(tentative_sk_blocks % sk_tiles != 0) + { + // penalty for uneven divide + cross_sk_blocks_overhead += + sk_blocks_per_tile * tentative_sk_iters_per_block / 50; + } + + uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead; + + if(tentative_sk_score < best_sk_score) + { + best_sk_score = tentative_sk_score; + sk_num_blocks = tentative_sk_blocks; + } + } + + if(best_sk_score >= dp_for_sk_iters) + { + sk_num_blocks = 0; + } + + // give a chance to control num of sk blocks + sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks; + + if(sk_num_blocks == 0) + { + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + else + { + // k_iters_per_sk_block is the floor of avg each ck block loop over tiles. + // we need to decide how many iters for each sk block + // let m = k_iters_per_sk_block + // some of the sk block (little) will cover m iters, some (big) will cover m+1 + // we have + // 1) l + b = sk_blocks + // 2) l * m + b * (m + 1) = sk_total_iters + // => (l + b) * m + b = sk_total_iters + // => sk_blocks * m + b = sk_total_iters + // => b = sk_total_iters - m * sk_blocks + // NOTE: big could be zero + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + + dp_num_blocks = dp_tiles; + dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu; + } + } + n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); + reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; + + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); + eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + +#if 0 + printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, " + "sk_num_blocks:%d, " + "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, " + "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, " + "sk_tiles:%u, workspace(acc float):%u\n", + num_cu, + occupancy, + get_grid_dims().x, + num_tiles, + dp_tiles, + sk_num_big_blocks, + sk_num_blocks, + sk_total_iters, + dp_start_block_idx, + dp_iters_per_block, + dp_num_blocks, + k_iters_per_tile.get(), + k_iters_per_big_block, + reduction_start_block_idx, + get_sk_tiles(), + get_workspace_size(sizeof(float))); +#endif + } + + __host__ __device__ uint32_t get_sk_total_iters() const + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + __host__ __device__ uint32_t get_sk_tiles() const + { + // tiles for sk + uint32_t sk_total_iters = get_sk_total_iters(); + return k_iters_per_tile.div(sk_total_iters); + } + + __host__ __device__ dim3 get_grid_dims() const + { + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1); + } + else + return dim3(reduction_start_block_idx, 1, 1); + } + + __device__ uint32_t get_block_idx() const + { + // TODO: swizzle block index for better locality + return __builtin_amdgcn_readfirstlane(blockIdx.x); + } + + __device__ void + get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = get_sk_total_iters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + __device__ uint32_t get_current_iter_length(uint32_t iter_start, + uint32_t iter_end, + uint32_t total_iter_length) const + { + uint32_t iter_length_mod, iter_length_quo /*unused*/; + k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); + uint32_t current_iter_length = math::min( + iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length); + return current_iter_length; + } + + __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); } + + __device__ void + get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const + { + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); + } + + __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const + { + uint32_t m_tile_idx, n_tile_idx; + uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock); + n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx); + + // swizzle tile + uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock); + + uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m; + + const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem)) + ? tile_swizzle_sub_m + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m; + m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m, + n_tile_idx_with_adapt); + } + + __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const + { + static constexpr uint32_t alignment = 128; + uint32_t acc_buffer_bytes = + MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes; + return (acc_buffer_bytes + alignment - 1) / alignment * alignment; + } + + __host__ __device__ uint32_t get_workspace_size_for_semaphore() const + { + return get_sk_tiles() * sizeof(uint32_t); + } + + __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const + { + return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore(); + } + + __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, + const MDiv& eqav_tiles_) const + { + uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1); + uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1; + uint32_t quo_, rem_; + eqav_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_eqav_tiles_ + rem_; + } + + __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + __host__ __device__ uint32_t get_total_acc_buffers() const + { + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = + get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big); + uint32_t total_intersec_little = + get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const + { + // TODO: from big to little + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little); + return get_total_acc_buffers() - (touched_sk_blocks + current_intersec); + } + } + + __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little); + return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec); + } + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index aa34cfbf84ab9425a236646400c1a965c8004f74..b25f136a3713a412fbcb166f3df304a172ee42f7 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -101,8 +101,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -346,14 +346,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle remove_cvref_t; using DefaultBGridDesc_BK0_N_BK1 = remove_cvref_t; - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t; - using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t; - using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = + remove_cvref_t; + using CountGridDescriptor_MBlock_MPerBlock_NBlock = + remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2ETileMap = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp index fbe89e7e5e5936c2db1470e63f922cc7f6c5ff2a..69468c25befe036123cf1796798e6a59ea47f74b 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp index bdebe3816f22d6a4f10e913210abc8c8347e20d0..bd1e0585fc96ad29ee9e5336993f2ba6570fb31d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp index 1313ec9435e782ad9cc9b92580464f6d5257b93e..fc4f27e33b15b6bfcc7ce942764f9aa5e26cfe1f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp index 6836a66047531d9eef09dbffc58188f7696a783a..203be3c42d9d23c935d10f34ceb44245f7875ad0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp index 5986641c6f72326bfa7a718c9bdfce57b1949748..910c926c7e470e0a7daa1cd25f57716f5ce4af6b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index fccb127d0f342911ed6bb1f8396c9d9cd0eec921..c2f47bd44435bff2ea75cbbee625895cb0c1bef4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -102,8 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; template __host__ __device__ static constexpr auto @@ -286,8 +286,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index b9f4a3080a0e17ca0ba169bc8b83c006acb2305d..d2920570e4f2cefc883aa9bf4bf1d8b845146670 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -67,6 +67,8 @@ template ; + using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; - using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t; + using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = + remove_cvref_t; - using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2E1TileMap = remove_cvref_t; @@ -710,13 +715,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId I1, // NBlockID - I1, // MRepeat - I1, // NRepeat - I1, // MWaveId - I1, // NWaveId - I1, // MPerXdl - I1, // NGroupNum - I1, // NInputNum + m0, // MRepeat + n0, // NRepeat + m1, // MWaveId + n1, // NWaveId + m2, // MPerXdl + n2, // NGroupNum + n3, // NInputNum n4)); // registerNum auto d0s_thread_buf = generate_tuple( @@ -732,8 +737,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle const auto wave_id = GetGemm0WaveIdx(); const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 - constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, n2, n4)); + static_assert(CDE0BlockTransferSrcScalarPerVector <= n4, + "vector load must be not greater than n4"); + static_assert(n4 % CDE0BlockTransferSrcScalarPerVector == 0); auto d0s_threadwise_copy = generate_tuple( [&](auto i) { @@ -742,10 +748,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle A0B0B1DataType, decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - n4, + 9, // CDE0BlockTransferSrcVectorDim + CDE0BlockTransferSrcScalarPerVector, 1, false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], make_multi_index(block_work_idx[I0], // MBlockId @@ -898,66 +913,42 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle blockwise_gemm0, acc0_thread_buf, num_k_block_main_loop); - // bias+gelu + // multiple d + if constexpr(NumD0Tensor) { - static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) { - static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) { - static_for<0, n2, 1>{}([&](auto groupid) { - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).Run( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - d0s_grid_buf[i], - d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - d0s_thread_buf(i)); - }); - - static_for<0, n4, 1>{}([&](auto i) { - constexpr index_t c_offset = acc0_thread_desc.CalculateOffset( - make_tuple(mr, nr, groupid, i)); - - // get reference to src data - const auto src_data_refs = generate_tie( - // return type should be lvalue - [&](auto iSrc) -> const auto& { - return d0s_thread_buf[iSrc][i]; - }, - Number{}); - - // get reference to dst data - auto dst_data_refs = generate_tie( - // return type should be lvalue - [&](auto) -> auto& { - return acc0_thread_buf(Number{}); - }, - Number<2>{}); - - unpack2(cde0_element_op, dst_data_refs, src_data_refs); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0)); - }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + d0s_grid_buf[i], + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return acc0_thread_buf(i); }, + Number<2>{}); + + unpack2(cde0_element_op, dst_data_refs, src_data_refs); }); static_for<0, NumD0Tensor, 1>{}([&](auto i) { d0s_threadwise_copy(i).MoveSrcSliceWindow( d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0)); + make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); }); } + else + { + static_for<0, acc0_thread_buf.Size(), 1>{}( + [&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); }); + } // gemm1 { // TODO: explore using dynamic buffer for a1 thread buffer diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 6a6f19d71ef523697d69719de75b70bce3953e7e..18cfeebcf30d5ea8e44c9a5a5ff1cf564fef01ba 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,7 +80,8 @@ template + int D0sTransferSrcScalarPerVector = 4, + PipelineVersion PipelineVer = PipelineVersion::v1> struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle { static_assert(LoopSched == LoopScheduler::Default, @@ -113,8 +114,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; template __host__ __device__ static constexpr auto @@ -367,12 +368,14 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle Number{}); } - using D0sGridPointer = decltype(MakeD0sGridPointer()); - using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t; + using D0sGridPointer = decltype(MakeD0sGridPointer()); + using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = + remove_cvref_t; - using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; @@ -621,13 +624,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId I1, // NBlockID - I1, // MRepeat - I1, // NRepeat - I1, // MWaveId - I1, // NWaveId - I1, // MPerXdl - I1, // NGroupNum - I1, // NInputNum + m0, // MRepeat + n0, // NRepeat + m1, // MWaveId + n1, // NWaveId + m2, // MPerXdl + n2, // NGroupNum + n3, // NInputNum n4)); // registerNum auto d0s_thread_buf = generate_tuple( @@ -644,9 +647,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle const auto wave_id = GetGemm0WaveIdx(); const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 - constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, n2, n4)); - auto d0s_threadwise_copy = generate_tuple( [&](auto i) { using D0DataType = remove_cvref_t>; @@ -655,10 +655,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle D0DataType, decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, 9, - n4, + D0sTransferSrcScalarPerVector, 1, false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], make_multi_index(block_work_idx[I0], // MBlockId @@ -884,62 +893,35 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // multiple d if constexpr(NumD0Tensor) { - static_for<0, MXdlPerWave, 1>{}([&](auto mr) { - static_for<0, NXdlPerWave, 1>{}([&](auto nr) { - static_for<0, n2, 1>{}([&](auto groupid) { - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).Run( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - d0s_grid_buf[i], - d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - d0s_thread_buf(i)); - }); - - static_for<0, n4, 1>{}([&](auto i) { - constexpr index_t c_offset = acc0_thread_desc.CalculateOffset( - make_tuple(mr, nr, groupid, i)); - - // get reference to src data - const auto src_data_refs = generate_tie( - // return type should be lvalue - [&](auto iSrc) -> const auto& { - return d0s_thread_buf[iSrc][i]; - }, - Number{}); - - // get reference to dst data - auto dst_data_refs = generate_tie( - // return type should be lvalue - [&](auto) -> auto& { - return acc_thread_buf(Number{}); - }, - Number<2>{}); - - unpack2(c0de_element_op, dst_data_refs, src_data_refs); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 1, -NXdlPerWave, 0, 0, 0, 0, 0, 0)); - }); + static_assert(NXdlPerWave == n0); + static_assert(MXdlPerWave == m0); + + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + d0s_grid_buf[i], + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return acc_thread_buf(i); }, + Number<2>{}); + + unpack2(c0de_element_op, dst_data_refs, src_data_refs); }); static_for<0, NumD0Tensor, 1>{}([&](auto i) { d0s_threadwise_copy(i).MoveSrcSliceWindow( d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 1, -MXdlPerWave, 0, 0, 0, 0, 0, 0, 0)); + make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); }); } else diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index d6d2051113ff4fdebb74689293b00c532aae9c68..f4b82badf14584d5e59fda6dc6751a83aa67b32f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -113,8 +113,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; template __host__ __device__ static constexpr auto @@ -300,8 +300,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp index ede6a96dc9f53045b5fc029f4c4f4f6aa4fa4817..ed1ffdd85765e84feb2f6824a858ab842487c18c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp index 33c45a0f037b4c81fbaf810cc10e55f71a2ab254..b6c83af13a18639cf7f8312dd416dbf4d9cb5297 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp index 8b82b65540d1c5612002549f72708be07b649dae..d686c14b350953a48bf132694251ac677fb47c4f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index 05257d16275e5f9817ec5d9cbbd06b133f0b5e05..bf0e8c186c75cef05b08c7ef9e1bf16c016830a3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index b09a73590239609908265cd6b72e3e5bb26127aa..3ea72b85345869a938e52f872c2c586c6e36170f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index bebcdceb43521a3b489e0b4d04e847bcc6661b3d..3ced4b9ad61fe0be2783fb6f261b9131fb8d1622 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -67,7 +67,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -191,8 +191,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -346,14 +346,17 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; - using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; - using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using ReduceGridDescriptor_MBlock_MPerBlock = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp index 9c68b4f5c3ffbb13cd44701121fca7cc113d76cd..27f48a84ba72fe284a964ac295540eac1c4f6766 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index d46aea5e22d131a588c90fc36189c34bd5755242..1da7236978eb042c4270179dcc502ea3593021a2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -274,7 +274,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - // HACK: this force index data into SGPR + // HACK: this forces index data into SGPR const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); @@ -472,7 +472,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); - // LDS doubel buffer: load next data from device mem + // LDS double buffer: load next data from device mem a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); @@ -992,7 +992,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step); - // LDS doubel buffer: load next data from device mem + // LDS double buffer: load next data from device mem a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf); b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6eca77c89cc1751e974d5cac0dd368d55ae37d21 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -0,0 +1,702 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \ + defined(__gfx1101__) || defined(__gfx1102__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( + GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA)); + const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane( + GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane( + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_m_n); +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + static constexpr auto max_lds_align = math::lcm(AK1, BK1); + + using ThisThreadBlock = ThisThreadBlock; + // return block_id to C matrix tile idx (m0, n0) mapping + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); } + __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + AK0{CalculateAK0(K)}, + BK0{CalculateBK0(K)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t AK0; + index_t BK0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const ABDataType* p_a_grid_, + const ABDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ABDataType* p_a_grid; + const ABDataType* p_b_grid; + CDataType* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType); + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "Wrong! AK1 must be known at the time of compilation."); + static_assert(is_known_at_compile_time>::value, + "Wrong! BK1 must be known at the time of compilation."); + + static_assert( + MPerBlock % (MPerDpp * MDppPerWave) == 0, + "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave."); + static_assert( + NPerBlock % (NPerDpp * NDppPerWave) == 0, + "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave."); + + static_assert( + KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1."); + + static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! AK1Value must be divisible by " + "ABlockTransferDstScalarPerVector_K1"); + + static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! BK1Value must be divisible by " + "BBlockTransferDstScalarPerVector_K1"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if(problem.K % KPerBlock != 0) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = problem.K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const auto num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n) + { + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + + using BlockwiseGemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + } + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + __device__ static auto + MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __device__ static auto + MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(BK0, BK1Value))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + } + + __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[AK0PerBlock, MPerBlock] is in LDS + // b_mtx[BK0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + auto blockwise_gemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0); + // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock) + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 578665ea85fb723565e597bbec26457ffb449606..d710fc18944ddb538ffb2d0385db862958e0542f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -92,8 +92,8 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -300,8 +300,9 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 remove_cvref_t; using DefaultBGridDesc_BK0_N_BK1 = remove_cvref_t; - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // Support 2 dimension in the future. Not only M using RGridDescriptor_MBlock_MPerBlock = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index d3f81566e93ab6b72f4edc7acfe097ef50e72270..98ade85a3089ccc5638fb45d4270488ebab59f3a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -346,8 +346,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { @@ -565,10 +565,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle e_grid_desc_m_n); } - using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; using DsGridPointer = decltype(MakeDsGridPointer()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index ec1cc53991d128a8f281e01289a6a8625810971a..1d920fb44d183b144811053e3231c526d808858f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -15,6 +15,9 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + namespace ck { // GEMM: @@ -26,7 +29,9 @@ namespace ck { // E = cde_op(C, D0, D1, ...) // Assume: // D0, D1, ... and E have the same layout -template {}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -89,18 +96,14 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; - // denorm test fix, required to work around fp16 mfma issue - // we convert fp16->fp32->bf16 and execute bf16 mfma instruction - // when mfma if fixed, remove this section and update - // ABDataTypeAdjusted -> ABDataType throughout this file #if CK_WORKAROUND_DENORM_FIX - using ABDataTypeAdjusted = - conditional_t, ck::bhalf_t, ABDataType>; + using ComputeDataType = + conditional_t, ck::bhalf_t, ComputeDataType_>; #else - using ABDataTypeAdjusted = ABDataType; + using ComputeDataType = ComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -170,7 +173,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ABDataType), + sizeof(ComputeDataType), c_block_size * sizeof(CShuffleDataType)); } @@ -265,13 +268,16 @@ struct GridwiseGemmMultipleD_xdl_cshuffle static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto AK = a_grid_desc_m_k.GetLength(I1); + const auto BK = b_grid_desc_n_k.GetLength(I1); // check consistency of desc - if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) { return false; } @@ -289,13 +295,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } // check tile size - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) { return false; } // check gridwise gemm pipeline - const auto num_k_loop = K / KPerBlock; + const auto num_k_loop = AK / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { @@ -312,8 +318,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // check tensor size: cannot be larger than 2GB each constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && - b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { return false; @@ -331,14 +337,102 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using DsGridPointer = decltype(MakeDsGridPointer()); + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + template - __device__ static void Run(const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid, void* __restrict__ p_shared, @@ -407,8 +501,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + ADataType, + ComputeDataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -438,8 +532,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Sequence, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + BDataType, + ComputeDataType, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -469,11 +563,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // sanity check constexpr index_t KPack = math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ABDataTypeAdjusted, + ComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -491,11 +585,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); @@ -760,6 +853,85 @@ struct GridwiseGemmMultipleD_xdl_cshuffle }); } } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + // tensor descriptors for problem definiton + const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K(M, K, StrideA); + const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K(K, N, StrideB); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + + const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e22391293ee885139aae51fd94ecc754e7eea7b6 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -0,0 +1,1086 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + // denorm test fix, required to work around fp16 mfma issue + // we convert fp16->fp32->bf16 and execute bf16 mfma instruction + // when mfma if fixed, remove this section and update + // ABDataTypeAdjusted -> ABDataType throughout this file +#if CK_WORKAROUND_DENORM_FIX + using ABDataTypeAdjusted = + conditional_t, ck::bhalf_t, ABDataType>; +#else + using ABDataTypeAdjusted = ABDataType; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ABDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataTypeAdjusted, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataTypeAdjusted, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ABDataTypeAdjusted, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + +#if 1 + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __syncthreads(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } +#endif + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + + __syncthreads(); + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 98331d854490abe4e7f478d817104218cb49c28f..f760feb2ed22015567ef66648d0d1afa81bdff9d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -1,8 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include + #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index d1209636de93a13cbe7f851575b4abaedd99af72..754a3e89c9327792fb370e30fbacdee2b9758163 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -4,7 +4,8 @@ #pragma once #include "ck/utility/common_header.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp index 3281b910d3d8782fc18f91ee77ce42aa44fc5ba9..d3d7d5af8577d1c2ed65fa817f708379c6089a25 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -79,6 +79,10 @@ struct GridwiseGemmPipeline_v2 do { +#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT + __builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT); +#endif + block_sync_lds(); // GEMM i diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ced62241cd29b54ba8b6a69c87896268b9b49a7e --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +struct GridwiseGemmPipeline_v3 +{ + __host__ __device__ static constexpr bool IsSupported(index_t) + { + // TODO: improve applicability + return true; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // global read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // LDS write 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + num_loop--; + + while(num_loop > 0) + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + block_sync_lds(); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + num_loop--; + } + // tail + { + block_sync_lds(); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index a3f5324713c2d61e94434370ee9040ab12475279..99e410f688312b47dbeba4afb1d1712624561a76 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -55,7 +55,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940)) + defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -164,8 +164,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -318,8 +318,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using ReduceGridDescriptor_MBlock_MPerBlock = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index aa89bff9ee21873fdc57c839eb800be455005a9a..18cf80041b71c9f581462cef10f7fd3508e96930 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -375,10 +375,12 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle remove_cvref_t; using DefaultBGridDesc_BK0_N_BK1 = remove_cvref_t; - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2ETileMap = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..caf8f040f4a78232386a1f0c3d0be02b4dc63295 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -0,0 +1,1076 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ComputeType, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + ComputeType, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + +#if 1 + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __syncthreads(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } +#endif + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + + __syncthreads(); + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp index 2d3a36fca08d7811e406a89b398f401b741a7649..de5a4241986ff64e1797f2304da117178059ea4b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 397ae1c1b9b7b9763b7eb39855ddf4c752db949d..d8b31311b1040f6c77e955550875a2a3e1337530 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -138,8 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { @@ -308,8 +308,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index e6303d76c1fdc8cbff72633bb097987ebdfc0224..9c09f3a53912c1fb703292c6ae6be7782fc1d7b1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -25,7 +25,7 @@ __global__ void kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -35,18 +35,22 @@ __global__ void #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } -template +template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); @@ -61,7 +65,8 @@ __global__ void template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeType = FloatC> struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -463,8 +469,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // Argument struct Argument : public tensor_operation::device::BaseArgument, public Problem { - __host__ Argument(const FloatAB* p_a_grid_, - const FloatAB* p_b_grid_, + __host__ Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, FloatC* p_c_grid_, index_t M_, index_t N_, @@ -479,14 +485,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { } - const FloatAB* p_a_grid; - const FloatAB* p_b_grid; + const FloatA* p_a_grid; + const FloatB* p_b_grid; FloatC* p_c_grid; }; // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -541,8 +547,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), + return math::max((a_block_space_size_aligned * sizeof(ComputeType) + + b_block_space_size_aligned * sizeof(ComputeType)), c_block_size * sizeof(FloatCShuffle)); } @@ -676,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, const Problem& problem) @@ -743,8 +749,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatA, + ComputeType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -774,8 +780,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 Sequence, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatB, + ComputeType, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -805,11 +811,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // sanity check constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - FloatAB, + ComputeType, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -827,10 +833,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 2d4ebe7076408b29cffb863630e903a195ceb6de..0404d88ab8b4faebbe8613453cb4274cf631fe2b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -58,7 +58,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // TODO ANT: separate into MMA + Epilogue @@ -173,8 +173,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -345,8 +345,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using C0GridDescriptor_NBlock_NPerBlock = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index acece0fbba42bc322cb3102aaf8898a77b41b9d0..bbd01a238e5d5f5d071ddec519a08d68e8fef206 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -330,8 +330,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle return e_grid_desc_mblock_mperblock_nblock_nperblock; } - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2ETileMap = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index da7ad1cacf1c7a1a0b2b97c6cdf4b6e87b80a731..0920a17fc7f9cfa6265f1599d6deade702fd53ed 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -166,7 +166,7 @@ __global__ void const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -259,8 +259,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; // denorm test fix, required to work around fp16 mfma issue // we convert fp16->fp32->bf16 and execute bf16 mfma instruction diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index 8d86f3c1d75190bd2c285044c234d17f41e7f7d2..b12bcee0f414c6e56b6550c084630b86ab153bb5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,7 +45,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..70abcac0b108722c33eb3e58df17c028c56ef8f3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -0,0 +1,1184 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp" +#include "ck/utility/workgroup_barrier.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, + const typename GridwiseGemm::FloatAB* p_b_grid, + typename GridwiseGemm::FloatC* p_c_grid, + void* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + typename GridwiseGemm::Block2CTileMap block_mapping) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_workspace, + M, + N, + K, + StrideA, + StrideB, + StrideC, + block_mapping, + static_cast(p_shared)); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_workspace; + ignore = M; + ignore = N; + ignore = K; + ignore = StrideA; + ignore = StrideB; + ignore = StrideC; + ignore = block_mapping; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + static constexpr auto KPerBlock = K0PerBlock * K1; + + using ThisThreadBlock = ThisThreadBlock; + using FloatAcc = FloatAcc_; + using FloatCShuffle = FloatAcc; + + using Block2CTileMap = Block2CTileMap_; + using FloatAB = FloatAB_; + using FloatC = FloatC_; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + Block2CTileMap block_mapping; + + Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + uint32_t num_cu, + uint32_t occupancy, + uint32_t num_sk_blocks_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + block_mapping(M, N, K, num_cu, occupancy, num_sk_blocks_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; } + + __host__ __device__ static auto + MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA) + { + const index_t K0 = CalculateK0(KPad); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor(a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __host__ __device__ static auto + MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB) + { + const index_t K0 = CalculateK0(KPad); + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor(b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return transform_tensor_descriptor(c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle().GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) + { + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + return false; + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + return false; + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + return false; + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + return false; + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + return false; + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = K0 > K0PerBlock; + + return has_main_k0_block_loop; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})); + } + + __host__ __device__ static constexpr auto GetClusterLengthReduction() + { + // TODO: assume C is row major + // TODO: we always first loop over N, then M + constexpr auto NPerBlockPow2 = math::next_power_of_two(); + constexpr auto NPerBlockReduction = + NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL; + constexpr auto MPerBlockReduction = + (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction; + return Sequence{}; + } + + __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor() + { + const auto c_partial_acc_block_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(NPerBlock, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(I1, MPerBlock)); + } + }(); + return c_partial_acc_block_m_n; + } + + using CGridDesc_M_N = remove_cvref_t; + + __device__ static void Run(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + void* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + Block2CTileMap block_mapping, + void* __restrict__ p_shared_block) + { + uint32_t m = M; + uint32_t n = N; + uint32_t k = K; + uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock; + uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock; + uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock; + uint32_t stride_a = StrideA; + uint32_t stride_b = StrideB; + uint32_t stride_c = StrideC; + + const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a); + const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const AElementwiseOperation a_element_op = AElementwiseOperation{}; + const BElementwiseOperation b_element_op = BElementwiseOperation{}; + const CElementwiseOperation c_element_op = CElementwiseOperation{}; + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = static_cast(p_shared_block); + FloatAB* p_b_block = static_cast(p_shared_block) + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3(); + + uint32_t block_idx = block_mapping.get_block_idx(); + bool is_sk_block = block_idx < block_mapping.sk_num_blocks; + bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx && + block_idx < block_mapping.reduction_start_block_idx; + bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx; + bool is_padding_block = block_idx >= block_mapping.sk_num_blocks && + block_idx < block_mapping.dp_start_block_idx; + uint32_t iter_start, iter_end; + block_mapping.get_block_itr(block_idx, iter_start, iter_end); + uint32_t total_iter_length = iter_end - iter_start; + + if(is_padding_block) + return; + + uint32_t* p_semaphore = + reinterpret_cast(reinterpret_cast(p_workspace) + + block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc))); + + if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction) + { + if(is_reduction_block) + { + // descriptors + constexpr auto cluster_length_reduce = GetClusterLengthReduction(); + constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce); + const auto reduce_thread_cluster_idx = + reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0]; + const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1]; + + constexpr auto MReduceIters = + math::integer_divide_ceil(Number{}, cluster_length_reduce.At(I0)); + constexpr auto NReduceIters = math::integer_divide_ceil( + Number{}, + cluster_length_reduce.At(I1) * + Number{}); + + constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{})); + constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, I1, I1, Number{})); + + constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); + + constexpr auto partial_acc_load_step_n = make_multi_index( + 0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_load_step_n_reverse = + make_multi_index(0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_load_step_m = + make_multi_index(cluster_length_reduce.At(I0), 0); + + constexpr auto partial_acc_store_step_n = make_multi_index( + 0, + 0, + 0, + cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_store_step_n_reverse = + make_multi_index(0, + 0, + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_store_step_m = + make_multi_index(0, cluster_length_reduce.At(I0), 0, 0); + + StaticBuffer + parcial_acc_buf; + StaticBuffer + acc_buf; + + // start to compute + auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx; + auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n); + + workgroup_barrier wg_barrier(p_semaphore); + + uint32_t tile_acc_offset_start = + block_mapping.get_acc_buffer_offset_from_tile(reduction_idx); + uint32_t tile_acc_offset_end = + block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1); + + auto acc_load = ThreadwiseTensorSliceTransfer_v2< + FloatAcc, // SrcData, + FloatAcc, // DstData, + decltype(c_partial_acc_block_m_n), // SrcDesc, + decltype(acc_thread_buf_load_desc), // DstDesc, + Sequence<1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths, + Sequence<0, 1>, // DimAccessOrder, + 1, // SrcVectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector, + 1, // SrcScalarStrideInVector, + false // SrcResetCoordinateAfterRun, + >{c_partial_acc_block_m_n, + make_multi_index(thread_m_cluster_id, + thread_n_cluster_id * + CBlockTransferScalarPerVector_NWaveNPerXDL)}; + + auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, // SrcData, + FloatC, // DstData, + decltype(acc_thread_buf_store_desc), // SrcDesc, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, + CElementwiseOperation, // ElementwiseOperation, + Sequence<1, 1, 1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths, + Sequence<0, 1, 2, 3>, // DimAccessOrder, + 3, // DstVectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector, + InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, + 1, // DstScalarStrideInVector, + false // DstResetCoordinateAfterRun, + >{c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + thread_m_cluster_id, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + thread_n_cluster_id * + CBlockTransferScalarPerVector_NWaveNPerXDL), + CElementwiseOperation{}}; + + // block synchronization + wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start); + +#if 0 + if(threadIdx.x == 0) { + printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast(blockIdx.x), + reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end), + __builtin_amdgcn_readfirstlane(spatial_idx[I0]), + __builtin_amdgcn_readfirstlane(spatial_idx[I1])); + } +#endif + + using Accumulation = ck::detail:: + AccumulateWithNanCheck; + + for(int i_m = 0; i_m < MReduceIters; i_m++) + { + static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) { + acc_buf.Clear(); + for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++) + { + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + + i * c_partial_acc_block_m_n.GetElementSpaceSize(), + c_partial_acc_block_m_n.GetElementSpaceSize()); + + acc_load.Run(c_partial_acc_block_m_n, + c_partial_acc_buf, + acc_thread_buf_load_desc, + make_tuple(I0, I0), + parcial_acc_buf); + + static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}( + [&](auto i_vec) { + constexpr auto offset = + acc_thread_buf_load_desc.CalculateOffset( + make_tuple(0, i_vec)); + Accumulation::Calculate(acc_buf(Number{}), + parcial_acc_buf[Number{}]); + }); + } + + if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL < + NPerBlock) + { + acc_store.Run(acc_thread_buf_store_desc, + make_tuple(I0, I0, I0, I0), + acc_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(NReduceIters != 1) + { + if constexpr(i_n_reduce != (NReduceIters - 1)) + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n); + } + else + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n_reverse); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n_reverse); + } + } + }); + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_m); + acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_m); + } + } + return; + } + } + + // offset for last acc buffer of this block + uint32_t block_acc_offset = + (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock * + NPerBlock; + + while(true) + { + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length)); + uint32_t tile_idx, iter_offset; + block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock); + + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k0_m_k1_grid_desc), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_k0_m_k1_grid_desc, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k0_n_k1_grid_desc), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_k0_n_k1_grid_desc, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + const index_t num_k_block_main_loop = current_iter_length; + + gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_k0_n_k1_grid_desc, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle = + GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle(); + + constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = + GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle(); + + auto c_block_buf = make_dynamic_buffer( + reinterpret_cast(p_shared_block), + c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize()); + + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + block_acc_offset, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_tuple(make_freeze_transform(I0), // freeze mblock + make_unmerge_transform( + make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform( + make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatCShuffle, + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + 0, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + 0), + c_element_op}; + + // LDS to global partial acc + auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatCShuffle, // typename DstData, + decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle), + decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be false, + // othre wise has scratch + false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be false, + // othre wise has scratch + {c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_multi_index(0, 0, 0, 0), + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_multi_index(0, 0, 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + c_block_copy_lds_to_global.SetSrcSliceOrigin( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_tuple(0, 0, 0, 0)); + + // LDS to global + if(is_dp_block) + c_block_copy_lds_to_global.template Run( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + else if(is_sk_block) + { + if constexpr(Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + // constexpr offset + c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_tuple(0, 0, 0, 0)); + + c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0)); + + c_block_copy_lds_to_partial_acc + .template Run( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + c_block_buf, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + c_partial_acc_buf); + } + else if constexpr(Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Atomic) + { + c_block_copy_lds_to_global + .template Run( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + } + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + mxdlperwave_forward_step); + } + }); + + if constexpr(Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + if(is_sk_block) + { + // increase the counter for this tile + workgroup_barrier wg_barrier(p_semaphore); + wg_barrier.inc(tile_idx); + } + } + } + + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + + if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction) + { + block_acc_offset -= MPerBlock * NPerBlock; + } + // make sure next loop LDS is ready for use + block_sync_lds(); + } + } + + template + struct LStr + { + static std::string Get() { return ""; } + }; + + template <> + struct LStr + { + static std::string Get() { return "R"; } + }; + + template <> + struct LStr + { + static std::string Get() { return "C"; } + }; + + static std::string GetTypeString() + { + auto str = std::stringstream(); + + // clang-format off + str << "GemmXdlStreamK_" + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << "_" + << "B" << BlockSize << "_" + << "Vec" << ABlockTransferSrcScalarPerVector << "x" + << BBlockTransferSrcScalarPerVector << "x" + << CBlockTransferScalarPerVector_NWaveNPerXDL << "_" + << MPerBlock << "x" + << NPerBlock << "x" + << K0PerBlock << "x" + << K1 ; + // clang-format on + + return str.str(); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index f4504a940220bd8dc16814526844f67154097836..b6c146ae615e7a525c2970602b3ff22139972f15 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -7,6 +7,7 @@ #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" @@ -21,30 +22,24 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M_N c_grid_desc_m_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -53,22 +48,49 @@ __global__ void p_shared, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + c_grid_desc_m_n); #else ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; ignore = a_grid_desc_k0_m_k1; ignore = b_grid_desc_k0_n_k1; - ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = block_2_ctile_map; + ignore = c_grid_desc_m_n; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); +#else + ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -77,9 +99,6 @@ template ; - using GridwiseGemmPipe = remove_cvref_t())>; + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + template + __host__ static auto CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1); + } + + template + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + K0{CalculateK0(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "K0:" << K0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t K0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; // denorm test fix, required to work around fp16 mfma issue // we convert fp16->fp32->bf16 and execute bf16 mfma instruction @@ -204,13 +322,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - template + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + const CGridDesc_M_N& c_grid_desc_m_n) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -239,7 +355,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return false; } - if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + // check gridwise gemm pipeline + const index_t K0 = problem.K / K1Value; + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -248,15 +381,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return true; } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } + template __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n) { constexpr auto max_lds_align = K1; @@ -306,31 +440,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( - const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + template + __device__ static void Run(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) { - return BlockToCTileMap_M00_N0_M01Adapt( - c_grid_desc_m_n); - } + const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) - { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -338,7 +464,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; // divide block work by [M, N] const auto block_work_idx = @@ -467,6 +598,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); // gridwise GEMM pipeline + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, @@ -565,4 +697,346 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } }; +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext + : GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + using Parent = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + using typename Parent::GridwiseGemmPipe; + using typename Parent::Problem; + + using Parent::I1; + + using Parent::K1; + + __device__ static auto + MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + __device__ static auto + MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor(c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const index_t K0 = problem.K / K1; + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 55f465a037f438287d176cb131b3ae96c0d5138d..19fbee727f77ae8b6b6f90a5a6f520accde95332 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -43,7 +43,7 @@ __global__ void const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index d56d1986e274e4aee3400e2b0def6245a7bf78fd..371281dfe20dddacf861677a771c3a425028a27f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -31,7 +31,7 @@ __global__ void const Block2CTileMap& b2c_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -45,7 +45,8 @@ __global__ void } template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeType = FloatC> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { static constexpr auto I0 = Number<0>{}; @@ -108,13 +110,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; struct Argument : public ck::tensor_operation::device::BaseArgument { - const FloatAB* p_a_grid; - const FloatAB* p_b_grid; + const FloatA* p_a_grid; + const FloatB* p_b_grid; FloatC* p_c_grid; index_t M; index_t N; @@ -128,8 +130,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 index_t K0; index_t k_batch; - Argument(const FloatAB* p_a_grid_, - const FloatAB* p_b_grid_, + Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, FloatC* p_c_grid_, index_t M_, index_t N_, @@ -365,7 +367,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType), c_block_size * sizeof(FloatC)); } @@ -577,8 +579,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 void* __restrict__ p_shared_block, const Block2CTileMap& block_2_ctile_map) { - const FloatAB* p_a_grid = karg.p_a_grid; - const FloatAB* p_b_grid = karg.p_b_grid; + const FloatA* p_a_grid = karg.p_a_grid; + const FloatB* p_b_grid = karg.p_b_grid; FloatC* p_c_grid = karg.p_c_grid; const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); @@ -698,8 +700,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 Sequence<1, K0PerBlock, MPerBlock, K1>, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatA, + ComputeType, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -728,8 +730,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 Sequence<1, K0PerBlock, NPerBlock, K1>, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatB, + ComputeType, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -759,7 +761,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - FloatAB, + ComputeType, FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), @@ -776,8 +778,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto a_block_space_size = math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - FloatAB* p_a_block = static_cast(p_shared_block); - FloatAB* p_b_block = static_cast(p_shared_block) + a_block_space_size; + ComputeType* p_a_block = static_cast(p_shared_block); + ComputeType* p_b_block = static_cast(p_shared_block) + a_block_space_size; constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); @@ -787,53 +789,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto b_block_buf = make_dynamic_buffer( p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); -#if 0 - // preload data into LDS - { - a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); - - a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (karg.K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } -#else // gridwise GEMM pipeline const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / @@ -856,7 +811,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 blockwise_gemm, c_thread_buf, num_k_block_main_loop); -#endif // output: register to global memory { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 8259927fecefa8bacd416289530c5b359b9aa5a4..8d7bd9a8d1eeae302e97fd465ba1678eac1b8bba 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -47,7 +47,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -139,8 +139,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -315,8 +315,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 c_grid_desc_m_n); } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using DefaultBlock2CTileMap = @@ -634,10 +634,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, FloatCShuffle, // typename SrcData, FloatC, // typename DstData, - decltype( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), - decltype( - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, 5, // index_t VectorDim, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index 5d5fdae170b17970b6105406b8e19889dd766549..79202cb5cfc90bef9f3de0781e67655b3650fd71 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -50,7 +50,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -142,8 +142,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { @@ -323,13 +323,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using DefaultBlock2CTileMap = @@ -654,12 +654,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 FloatC, // typename Src0Data, FloatC, // typename Src1Data, FloatC, // typename DstData, - decltype( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), - decltype( - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), - decltype( - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, 5, // index_t VectorDim, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index dc83f8e98493d209104fd4cbdd7f1254b2de72a2..0d461b4fbf4cc0d78cfe347ea4fe6d63d1f28db6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -54,7 +54,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { @@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 c_grid_desc_m_n); } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using DefaultBlock2CTileMap = @@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 FloatC, // typename Src1Data, FloatC, // typename Src2Data, FloatC, // typename DstData, - decltype( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), - decltype( - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), - decltype( - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), - decltype( - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, 5, // index_t VectorDim, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp new file mode 100644 index 0000000000000000000000000000000000000000..93625a324e542cfcdb6b5bc4c583b41e21192958 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseImageToColumn +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using ThisThreadBlock = ThisThreadBlock; + + __device__ static void Run(const InputGridDesc& in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc& out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const Block2ETileMap& block_2_tile_map) + { + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t k_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock); + + // Global Memory + const auto in_global_buf = make_dynamic_buffer( + p_in_global, in_grid_desc.GetElementSpaceSize()); + auto out_global_buf = make_dynamic_buffer( + p_out_global, out_grid_desc.GetElementSpaceSize()); + + auto copy_global_to_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + Tuple, + Tuple, + decltype(tie(in_grid_desc)), + decltype(tie(out_grid_desc)), + tensor_operation::element_wise::PassThrough, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ThreadClusterLengths, + Sequence<0, 1>, + Sequence<0, 1>, + I1, + ScalarPerVector, + Sequence, + Sequence>{ + in_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + out_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + tensor_operation::element_wise::PassThrough{}}; + + copy_global_to_global.Run( + tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf)); + } + + __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc, + const OutputGridDesc& out_grid_desc) + { + if(in_grid_desc.GetLength(I0) % MPerBlock != 0 || + in_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + if(out_grid_desc.GetLength(I0) % MPerBlock != 0 || + out_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + return true; + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp index de1ae915920503a086d2cf939ae6e41d27a28838..61d0f9e0d5573c774bbb1dcba707e8c2e775f075 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c41eef8c45fa264a85237c116d223db8e42c48b2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, + const InDataType* __restrict__ p_in_global, + const IndexDataType* __restrict__ p_indices_global, + OutDataType* __restrict__ p_out_global, + const ElementwiseOperation elementwise_op) +{ + GridwisePutElementwise1dFunctor::Run( + in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op); +} + +// output[indices] = input +template +struct GridwisePutElement_1D +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + __device__ static void Run(const InGrid1dDesc& in_grid_1d_desc, + const InDataType* __restrict__ p_in_global, + const IndexDataType* __restrict__ p_indices_global, + OutDataType* __restrict__ p_out_global, + const ElementwiseOperation& elementwise_op) + { + // Global Memory + const auto in_global_buf = make_dynamic_buffer( + p_in_global, in_grid_1d_desc.GetElementSpaceSize()); + + const auto indices_global_buf = + make_dynamic_buffer(p_indices_global, + in_grid_1d_desc.GetElementSpaceSize(), + NumericLimits::Lowest()); + + // VGPR + StaticBuffer in_thread_buf; + StaticBuffer indices_thread_buf; + + // Thread id, Block id and index + const index_t thread_global_id = get_thread_global_1d_id(); + const auto thread_global_offset = make_multi_index(thread_global_id * InVectorSize); + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = in_grid_1d_desc.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * InVectorSize; + const auto loop_step_index = make_multi_index(loop_step); + + auto in_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InVectorSize, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc, thread_global_offset}; + + auto indices_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InVectorSize, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc, thread_global_offset}; + + index_t num_iter = M / loop_step; + do + { + in_global_load.Run(in_grid_1d_desc, + in_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + in_thread_buf); + + in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index); + + static_for<0, InVectorSize, 1>{}( + [&](auto iM) { elementwise_op(in_thread_buf(iM), in_thread_buf[iM]); }); + + indices_global_load.Run(in_grid_1d_desc, + indices_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + indices_thread_buf); + + indices_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index); + + static_for<0, InVectorSize, 1>{}([&](auto iM) { + if(indices_thread_buf[iM] >= 0) + { + if constexpr(MemOp == InMemoryDataOperationEnum::Set) + { + // User should guarantee each index in p_indices_global is different + *(p_out_global + indices_thread_buf[iM]) = + ck::type_convert(in_thread_buf[iM]); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicAdd) + { + atomic_add(p_out_global + indices_thread_buf[iM], + ck::type_convert(in_thread_buf[iM])); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax) + { + atomic_max(p_out_global + indices_thread_buf[iM], + ck::type_convert(in_thread_buf[iM])); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::Add) + { + // User should guarantee each index in p_indices_global is different + *(p_out_global + indices_thread_buf[iM]) += + ck::type_convert(in_thread_buf[iM]); + } + else + { + static_assert(MemOp == InMemoryDataOperationEnum::Set || + MemOp == InMemoryDataOperationEnum::AtomicAdd || + MemOp == InMemoryDataOperationEnum::AtomicMax || + MemOp == InMemoryDataOperationEnum::Add); + } + } + }); + + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp index 901e7aee98a7b6ab1c37b7dd4dc36f796632fef6..41352fabeb361b017949738bda31649a3ce0dd40 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp index 88c7b6acfeb9c43ec0e8b5a5ab557f834430fb40..0ad36b418a42bff56eb87ba35bb5150808e9cac4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp index 0344e68305b1a951cfc41a53666090bd1e0113b6..5f56ac6fc4664be1685a950dc6e242b3e256e178 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp index ff2511fa6e61eda6c4cca976d0372cf98840ee5c..287b4e5421682eb3596f0d47b46bf769d32377b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple( Number{}, Number{}))); - using ThreadwiseWolfordDescReduce = decltype( - make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}))); using ThreadwiseWelford = ThreadwiseWelford; diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp index 792ffabcb90c71b4620915572c29091b67c8c4d6..c3f122106df29a471e376c453d055806c9aa2514 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp index 632690e1eff9af4cb0ab655250dca381c625776a..e50fb9813325b08576f49d321e7a5de5d51bbff2 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp index 129b4e1165d34a25087722d4d0c2c0bd8b160c30..80e9a84f9681f6c630db441a7f2de33826f19cf9 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st static constexpr auto ThreadBufferNumber = Number{}; __device__ static int - GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id) + GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id) { bool is_rightmost_block = block_k_cluster_id == kGridSize - 1; if(is_rightmost_block) { - int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); - int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock; - int kPerThread = - kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); - int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; + int left_kPerBlock = math::integer_divide_ceil(k, kGridSize); + int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1); + int kPerThread = kRightmostBlock < K_BlockTileSize + ? 0 + : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize); + int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize; if(kPerBlockTail > 0) { @@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st } else { - int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize); + int kPerBlock = math::integer_divide_ceil(k, kGridSize); return KThreadSliceSize * (kPerBlock / K_BlockTileSize); } } @@ -193,10 +194,13 @@ struct GridwiseNormalizationSplitK1st auto var_global_val_buf = make_dynamic_buffer( p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); - auto threadwise_welford = ThreadwiseWelford(); - int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; - threadwise_welford.max_count_ = - GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id); + auto threadwise_welford = ThreadwiseWelford(); + int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1), + kRaw, + k_grid_size, + block_k_cluster_id, + thread_k_cluster_id); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { mean_thread_buf(I) = type_convert(0.0f); diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp index d796d1afc0c7960f80ffda59ef50e01f1e5bf98b..136ac94e7f0fabb81cfd5baea7e58174799f4d43 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp index 3a7ae459e5f6e8d58edc49ae53c4309030e911f9..ff9712276c26d14c9a68f9df618ed4666f2dc5e9 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp index 188c62d93b0d60317ac139115b79f496612d2ea6..c6eecc067d9f77fa488e7a922fe1fb711c0c0d0f 100644 --- a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp index 94cdfe01087bc4005d45f11992a6a8f94d42691c..44730d551c5c8c9fb2408502ece0d45c6f37a221 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp index e045e3b545a9897f5df9df1948123e86424d7c7b..e97aa433a6f3b9d54fe76ed26b5b0f6bd439ae84 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP #define CK_THREADWISE_GEMM_DLOPS_V3_HPP diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp index 0a1197a1630c12c53afa9513f10fcfa907ba82d3..6774a35bcb872466d272da0bccd618709d3fb29b 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 16484ddcc5a89792c1331a48516e8317b36f2633..605f2569c6622bb589d46da71020ea46ddba69f7 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 6665d765f81ccd45f36598b104711e77368b6ee5..32ea8ae39c60c0e2130ab9381d94ecd8ea399982 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; constexpr auto ordered_src_access_lengths = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp index 6e8a23930bbb91677ee18bab216af6c45de72e4c..6a6c1f2ac5bfed7f02c5f5895d16002226cff56c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index f13da341f9b2d36d992b1cd1835b43c89dcabd1f..bd01108b03ca82d4877e19806521efda8da8c6c3 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index 9c91cd9ca8f86df37fee6dfdf597be1cefe2d683..644877d3931f08ea46b017446140484de8746fa9 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1 // apply pointwise operation static_for<0, ScalarPerVector, 1>{}([&](auto i) { - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_vector_container.template AsType()[i]); // apply type convert - dst_vector_container.template AsType()(i) = type_convert(v); + dst_vector_container.template AsType()(i) = v; }); const bool is_dst_valid = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..88ed217547088b685a5291026190c74db9305da1 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r1r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over space-filling curve + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + SrcData v; + + // apply element-wise operation + element_op_(v, src_vector_container.template AsType()[i]); + + // apply type convert + dst_vector_container.template AsType()(i) = type_convert(v); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRun + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp index 68bc2726f4be6fa1b08fe9213e0db9440526f975..cf2c7a2aee3d3f612fd12c0868477bfdfe49a2ea 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp index 0f5fb88b04540dbf355fb0f02998303b7680c193..b5847e51b42ea5b567133f0307b319fa3d8fadbb 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp index 2eb1b0ee90a446ee3b14354aff018275cf7809f4..db7dee21992d0e4edaa2e7a443b5c0ed68589107 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp index 12ba2c5381311ab45ffad4be793651d176f5fc59..eb6715e8ebb382e438791f60193bfc23f06d1773 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..03a4d17c9b6165fca2551c91fbd6ab26929ca2e9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +enum struct DppInstr +{ + dpp8_f16_16x16x2 = 0, + dpp8_f16_8x32x2, + dpp8_f16_32x8x2 +}; + +/** + * Structure representing DPP GEMM executed by a single wavefront. + * + * Each structure instantiation must contain the following fields: + * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the + * number of threads in a wavefront; + * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier, + * it's 8 in case of DPP8; + * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup; + * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup; + * - m_per_thread - size along M dimension of the tile calculated by a single thread; + * - n_per_thread - size along N dimension of the tile calculated by a single thread; + * - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation; + * - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers. + * + * Not all the combinarions are supported now, for current restrictions see the static asserts + * in the DppSelector's contructor. + */ +template +struct dpp_type; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 32; + static constexpr index_t n_per_wave = 8; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 16; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template +struct DppSelector +{ + template + static constexpr auto GetDpp(); + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_16x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_32x8x2; + } + + static constexpr auto selected_dpp = dpp_type()>{}; + + __host__ __device__ constexpr DppSelector() + { + static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0); + static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0); + + static_assert(selected_dpp.k_per_dpp % 2 == 0); + + static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0); + constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size; + constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave; + constexpr index_t num_dpp_c_elems = + selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup; + static_assert(num_wave_c_elems % num_dpp_c_elems == 0); + static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems); + + if constexpr(selected_dpp.share_a) + { + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + } + else + { + static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread); + } + + // Below checks come from the restrictions of the current implementation, could be removed + // in the future when the implementation is more generalized. + static_assert(selected_dpp.share_a); + static_assert(selected_dpp.n_per_thread == 1); + static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size); + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup == + selected_dpp.n_per_thread * selected_dpp.lanegroup_size); + } + + static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; } +}; + +template +struct DppGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __host__ __device__ constexpr DppGemm() + { + static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32, + "MPerDpp must be either 8, 16 or 32."); + static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32, + "NPerDpp must be either 8, 16 or 32."); + + static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); + } + + __device__ static constexpr index_t GetRegSizePerDpp() + { + return MPerDpp * NPerDpp / dpp_instr.wave_size; + } + + template + __device__ void + Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "base BaseType must be double, float, half, bfloat16, and int8_t!"); + + static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) { + dpp_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + }); + } + + __device__ static auto GetLaneIdInWave() + { + return get_thread_local_1d_id() % dpp_instr.wave_size; + } + + __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; } + + __device__ static auto GetLaneIdInLaneGroup() + { + return get_thread_local_1d_id() % dpp_instr.lanegroup_size; + } + + __device__ static auto GetLaneGroupIdInWave() + { + return GetLaneIdInWave() / dpp_instr.lanegroup_size; + } + + __device__ static auto GetDppOpIdx() + { + const auto lanegroupId = GetLaneGroupIdInWave(); + + constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup, + dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex( + make_multi_index(lanegroupId)); + + const auto m_dpp_idx = dpp_idx[I0]; + const auto n_dpp_idx = dpp_idx[I1]; + + return make_tuple(m_dpp_idx, n_dpp_idx); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M() + { + const auto laneId = get_thread_local_1d_id(); + const auto wave_row = laneId / dpp_instr.n_per_wave; + auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup(); + return make_tuple(0, m_idx % dpp_instr.m_per_wave); + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N() + { + const auto laneId = get_thread_local_1d_id(); + return make_tuple(0, laneId % dpp_instr.n_per_wave); + } + + __device__ static CIndex GetBeginOfThreadBlk() + { + const auto dpp_op_idx = GetDppOpIdx(); + + const auto m_dpp_op_idx = dpp_op_idx[I0]; + const auto n_dpp_op_idx = dpp_op_idx[I1]; + + index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup(); + index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup; + + return CIndex{m_offset, n_offset}; + } + + static constexpr auto dpp = DppSelector{}; + + static constexpr auto dpp_instr = dpp.selected_dpp; + + static constexpr auto K0PerDpp = 1; + static constexpr auto K1PerDpp = dpp.GetK1PerDpp(); + + __host__ __device__ static constexpr auto GetCMNThreadBlkLengths() + { + return make_tuple(Number{}, Number{}); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 24efeb2dea7ba1e18f91524014b50d819cdfd785..979f3567e987ed66fd45af5deeae7bb2ac3fc761 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 319487bc053ccb7ae4598ca467836fd07196d7c8..814969ef42baa2cc8e1f4188f61b3d50d74ea265 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -29,7 +29,9 @@ enum struct MfmaInstr mfma_i32_16x16x16i8, mfma_i32_32x32x16i8, mfma_i32_16x16x32i8, - mfma_f64_16x16x4f64 + mfma_f64_16x16x4f64, + mfma_f32_32x32x16f8f8, + mfma_f32_16x16x32f8f8 }; template @@ -454,6 +456,50 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f8f8::Run(a, b, reg_c); + } +}; + template struct MfmaSelector { @@ -594,6 +640,18 @@ struct MfmaSelector } #endif + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8f8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32f8f8; + } + static constexpr auto selected_mfma = mfma_type()>{}; __host__ __device__ constexpr MfmaSelector() @@ -794,7 +852,7 @@ struct XdlopsGemm { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value, "base base_type must be double, float, half, bfloat16, and int8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp index 5fc11d9158ab9a1261c1285c2cd9f5208d7afbf3..ea27a40ce3c7a48564af112a97af342c157f746d 100644 --- a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index 13d0a28cfe5b2db14a83a4b81e0bb379aabbe137..2be0b66812434eb5127f5a97534ca4dc36b68273 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,150 @@ namespace ck { namespace tensor_operation { +namespace { +template < + index_t NDimSpatial, + typename ALayout, + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization> +constexpr auto make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& out_g_n_k_wos_strides) +{ + const auto KStride = Number<1>{}; + + if constexpr(is_same_v) + { + const index_t NStride = out_g_n_k_wos_strides[1]; + const index_t HiStride = out_g_n_k_wos_strides[3]; + const index_t WiStride = out_g_n_k_wos_strides[4]; + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), + make_tuple(WiStride, KStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K), + make_tuple(NStride, HiStride, WiStride, KStride)); + } + } + else if constexpr(is_same_v) + { + const index_t NStride = out_g_n_k_wos_strides[1]; + const index_t DoStride = out_g_n_k_wos_strides[3]; + const index_t HoStride = out_g_n_k_wos_strides[4]; + const index_t WoStride = out_g_n_k_wos_strides[5]; + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Do, Ho, Wo, K), + make_tuple(NStride, DoStride, HoStride, WoStride, KStride)); + } + } + else if constexpr(is_same_v) + { + // assume packed + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + } + else + { + return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + } + } + else if constexpr(is_same_v) + { + // assume packed + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + } + else + { + return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + } + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + ALayout::name()); + } +} + +template +constexpr auto make_wei_grid_desc( + const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C) +{ + + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + BLayout::name()); + } +} + +template +constexpr auto make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& in_g_n_c_wis_strides) +{ + + if constexpr(is_same_v || + is_same_v || + is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(in_g_n_c_wis_strides[1], + in_g_n_c_wis_strides[3], + in_g_n_c_wis_strides[4], + in_g_n_c_wis_strides[2])); + } + else if constexpr(is_same_v || + is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C), + make_tuple(in_g_n_c_wis_strides[1], + in_g_n_c_wis_strides[3], + in_g_n_c_wis_strides[4], + in_g_n_c_wis_strides[5], + in_g_n_c_wis_strides[2])); + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + CLayout::name()); + } +} + +} // namespace + template < index_t NDimSpatial, ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, @@ -20,6 +164,7 @@ template < index_t BK1, index_t GemmMPerBlock, index_t GemmNPerBlock, + index_t GemmKPerBlock, bool DoPadGemmM, bool DoPadGemmN> struct TransformConvBwdDataToGemm_v1 @@ -27,13 +172,30 @@ struct TransformConvBwdDataToGemm_v1 static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = + NDimSpatial == 2 ? Number{} : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = + NDimSpatial == 2 ? Number{} : Number{}; + template , + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v || + is_same_v || + is_same_v), bool>::type = false> static auto MakeADescriptor_AK0_M_AK1( const std::array& out_g_n_k_wos_lengths, - const std::array& /* out_g_n_k_wos_strides */, + const std::array& out_g_n_k_wos_strides, const std::array& wei_g_k_c_xs_lengths, const std::array& /* wei_g_k_c_xs_strides */, const std::array& in_g_n_c_wis_lengths, @@ -44,44 +206,52 @@ struct TransformConvBwdDataToGemm_v1 const std::array& /* input_right_pads */, const std::array& tildes) { - index_t i_ytilde = tildes[0]; - index_t i_xtilde = tildes[1]; + index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; + index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; + index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum]; const index_t N = in_g_n_c_wis_lengths[1]; const index_t K = wei_g_k_c_xs_lengths[1]; - const index_t Hi = in_g_n_c_wis_lengths[3]; - const index_t Wi = in_g_n_c_wis_lengths[4]; - - const index_t Ho = out_g_n_k_wos_lengths[3]; - const index_t Wo = out_g_n_k_wos_lengths[4]; + const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1; + const index_t Hi = in_g_n_c_wis_lengths[HIdx]; + const index_t Wi = in_g_n_c_wis_lengths[WIdx]; - const index_t Y = wei_g_k_c_xs_lengths[3]; - const index_t X = wei_g_k_c_xs_lengths[4]; + const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; + const index_t Ho = out_g_n_k_wos_lengths[HIdx]; + const index_t Wo = out_g_n_k_wos_lengths[WIdx]; - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; + const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1; + const index_t Y = wei_g_k_c_xs_lengths[YIdx]; + const index_t X = wei_g_k_c_xs_lengths[XIdx]; - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; + const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum]; + const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum]; + const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - const index_t AK0 = K / AK1; + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; - // assume packed - const auto out_n_ho_wo_k_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d + const auto out_grid_desc = + make_out_grid_desc( + N, Do, Ho, Wo, K, out_g_n_k_wos_strides); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { + const index_t AK0 = math::integer_divide_ceil(K, AK1); + // A: output tensor const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), + out_grid_desc, + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_unmerge_transform(make_tuple(AK0, AK1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{})); @@ -96,103 +266,226 @@ struct TransformConvBwdDataToGemm_v1 } else { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = ConvStrideD / GcdStrideDilationD; const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto XTilde = ConvStrideW / GcdStrideDilationW; + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); const auto YDot = math::integer_divide_ceil(Y, YTilde); const auto XDot = math::integer_divide_ceil(X, XTilde); + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); const auto IHTildeSliceBegin = math::integer_divide_floor( math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); const auto IWTildeSliceBegin = math::integer_divide_floor( math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); const auto IHTildeSliceEnd = math::min( HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); const auto IWTildeSliceEnd = math::min( WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - // A: output tensor - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_n_ho_wo_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Ho, I0, I0), - make_pad_transform(Wo, I0, I0), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilde), - make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilde), - make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + if constexpr(NDimSpatial == 2) + { + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_unmerge_transform(make_tuple(AK0, AK1))), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5, 6>{})); - - const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(AK1)), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto out_gemmak0_gemmm_gemmak1_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - out_gemmak0_gemmmraw_gemmak1_grid_desc, - make_tuple(AK0, GemmMPerBlock, AK1), - Sequence{}); - - return out_gemmak0_gemmm_gemmak1_grid_desc; + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else if constexpr(NDimSpatial == 3) + { + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform( + make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform( + make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform( + make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } } } template , + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v), bool>::type = false> static auto MakeBDescriptor_BK0_N_BK1( const std::array& out_g_n_k_wos_lengths, @@ -207,35 +500,40 @@ struct TransformConvBwdDataToGemm_v1 const std::array& /* input_right_pads */, const std::array& tildes) { - index_t i_ytilde = tildes[0]; - index_t i_xtilde = tildes[1]; + index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; + index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; + index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum]; const index_t N = in_g_n_c_wis_lengths[1]; const index_t K = wei_g_k_c_xs_lengths[1]; const index_t C = wei_g_k_c_xs_lengths[2]; - const index_t Ho = out_g_n_k_wos_lengths[3]; - const index_t Wo = out_g_n_k_wos_lengths[4]; + const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; + const index_t Ho = out_g_n_k_wos_lengths[HIdx]; + const index_t Wo = out_g_n_k_wos_lengths[WIdx]; - const index_t Y = wei_g_k_c_xs_lengths[3]; - const index_t X = wei_g_k_c_xs_lengths[4]; + const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1; + const index_t Y = wei_g_k_c_xs_lengths[YIdx]; + const index_t X = wei_g_k_c_xs_lengths[XIdx]; - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t BK0 = K / BK1; + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; // assume packed - const auto wei_k_y_x_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + // k_y_x_c for 2d or k_z_y_x_c for 3d + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { + const index_t BK0 = math::integer_divide_ceil(K, BK1); + // B: weight tensor const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), @@ -243,7 +541,7 @@ struct TransformConvBwdDataToGemm_v1 make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1)); + make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1)); const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( @@ -255,75 +553,175 @@ struct TransformConvBwdDataToGemm_v1 } else { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = ConvStrideD / GcdStrideDilationD; const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto XTilde = ConvStrideW / GcdStrideDilationW; + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); const auto YDot = math::integer_divide_ceil(Y, YTilde); const auto XDot = math::integer_divide_ceil(X, XTilde); // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); // B weight tensor - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_k_y_x_c_grid_desc, - make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilde), - make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilde), - make_tuple(ConvStrideW / GcdStrideDilationW, I1)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(i_ytilde), - make_freeze_transform(i_xtilde), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0, 1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<4>{})); - - const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( - wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)), - make_pass_through_transform(C), - make_pass_through_transform(BK1)), - make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, + if constexpr(NDimSpatial == 2) + { + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, make_tuple( - wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0), GemmNPerBlock, BK1), - Sequence{}); - - return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<3>{})); + + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), + make_pass_through_transform(C)), + make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } + else if constexpr(NDimSpatial == 3) + { + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_grid_desc, + make_tuple( + make_pass_through_transform(K), + make_embed_transform(make_tuple(ZDot, ZTilde), + make_tuple(ConvStrideD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor( + wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ztilde), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), + make_pass_through_transform(C)), + make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + + const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemm_gemmbk1_grid_desc; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } } } template || + is_same_v || is_same_v || + is_same_v || is_same_v), bool>::type = false> static auto @@ -339,153 +737,309 @@ struct TransformConvBwdDataToGemm_v1 const std::array& input_right_pads, const std::array& tildes) { - index_t i_ytilde = tildes[0]; - index_t i_xtilde = tildes[1]; + index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; + index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; + index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum]; const index_t N = in_g_n_c_wis_lengths[1]; const index_t C = wei_g_k_c_xs_lengths[2]; - const index_t Hi = in_g_n_c_wis_lengths[3]; - const index_t Wi = in_g_n_c_wis_lengths[4]; + const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1; + const index_t Hi = in_g_n_c_wis_lengths[HIdx]; + const index_t Wi = in_g_n_c_wis_lengths[WIdx]; - const index_t Ho = out_g_n_k_wos_lengths[3]; - const index_t Wo = out_g_n_k_wos_lengths[4]; + const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; + const index_t Ho = out_g_n_k_wos_lengths[HIdx]; + const index_t Wo = out_g_n_k_wos_lengths[WIdx]; - const index_t Y = wei_g_k_c_xs_lengths[3]; - const index_t X = wei_g_k_c_xs_lengths[4]; + const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1; + const index_t Y = wei_g_k_c_xs_lengths[YIdx]; + const index_t X = wei_g_k_c_xs_lengths[XIdx]; - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; + const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum]; + const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum]; + const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum]; - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; + const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum]; + const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum]; + const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum]; - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; // assume strided - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), - make_tuple(in_g_n_c_wis_strides[1], - in_g_n_c_wis_strides[3], - in_g_n_c_wis_strides[4], - in_g_n_c_wis_strides[2])); + // n_hi_wi_c for 2d n_di_hi_wi_c for 3d + const auto in_grid_desc = + make_in_grid_desc(N, Di, Hi, Wi, C, in_g_n_c_wis_strides); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { // C: input tensor - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), - make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_freeze_transform(I0), - make_freeze_transform(I0), - make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), - make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( - in_gemmmraw_gemmnraw_grid_desc, - make_tuple(GemmMPerBlock, GemmNPerBlock), - Sequence{}); - - return in_gemmm_gemmn_grid_desc; + if constexpr(NDimSpatial == 2) + { + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else if constexpr(NDimSpatial == 3) + { + + // C: input tensor + const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_x_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6>{}, + Sequence<7>{}), + make_tuple( + Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } } else { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = ConvStrideD / GcdStrideDilationD; const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto XTilde = ConvStrideW / GcdStrideDilationW; + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - // only work on HTilde and WTilde that contribute to non-padding area of input tensor + // only work on DTilde, HTilde and WTilde that contribute to + // non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); const auto IHTildeSliceBegin = math::integer_divide_floor( math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); const auto IWTildeSliceBegin = math::integer_divide_floor( math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); const auto IHTildeSliceEnd = math::min( HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); const auto IWTildeSliceEnd = math::min( WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // C: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilde, HTilde), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilde, WTilde), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_freeze_transform(i_ytilde), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(i_xtilde), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<>{}, - Sequence<2>{}, - Sequence<3>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( - in_gemmmraw_gemmnraw_grid_desc, - make_tuple(GemmMPerBlock, GemmNPerBlock), - Sequence{}); - - return in_gemmm_gemmn_grid_desc; + if constexpr(NDimSpatial == 2) + { + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else if(NDimSpatial == 3) + { + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(ZTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + return in_gemmm_gemmn_grid_desc; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } } } diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 1b5e64b66cf292fbcf4979d95aa4a5b1fa3e3a8b..cee3d2825b13facc912cdc8de21774ce793e9a2c 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_address_space.hpp b/include/ck/utility/amd_address_space.hpp index 9f1525914cdee16ee6cbed51297f0008d46f6085..d54f70e750e8032f706ae3f7f0dfec9f7cab4c8d 100644 --- a/include/ck/utility/amd_address_space.hpp +++ b/include/ck/utility/amd_address_space.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index bdfb4f27580e011f0bebabfeae9146374a5ecd67..897cb4f249f415a1ca3a6660063ebc946641bd47 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" @@ -286,7 +286,22 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); -template +// memory coherency bit for buffer store/load instruction +// check ISA manual for each GFX target +// e.g. for +// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf, +// page 67~68 +enum struct AmdBufferCoherenceEnum +{ + DefaultCoherence = 0, // default value + GLC = 1, + SLC = 2, + GLC_SLC = 3, +}; + +template __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset) @@ -305,28 +320,37 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w // use fp32 load to mimic fp64 load if constexpr(N == 1) { - const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + const float2_t tmp = + llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); } else if constexpr(N == 2) { - const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + const float4_t tmp = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); } else if constexpr(N == 4) { - const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + const float4_t f32_0 = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); const float4_t f32_1 = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(float), - 0); + static_cast(coherence)); vector_type tmp; tmp.AsType()(Number<0>{}) = bit_cast(f32_0); @@ -339,31 +363,40 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_fp32( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_fp32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 8) { vector_type tmp; - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); tmp.AsType()(Number<1>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(float), - 0); + static_cast(coherence)); return tmp.AsType()(Number<0>{}); } @@ -372,24 +405,32 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_fp16( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_fp16x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_fp16x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 8) { // use fp32 load to mimic fp16 load - float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); } @@ -398,23 +439,31 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_i16( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_i16x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_i16x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 8) { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); } @@ -423,31 +472,40 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_i32( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { - return llvm_amdgcn_raw_buffer_load_i32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 4) { - return llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 8) { vector_type tmp; - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); tmp.AsType()(Number<1>{}) = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(int32_t), - 0); + static_cast(coherence)); return tmp.AsType()(Number<0>{}); } } @@ -455,17 +513,23 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w { if constexpr(N == 1) { - return llvm_amdgcn_raw_buffer_load_i8( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } else if constexpr(N == 2) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); #else - int16_t tmp = llvm_amdgcn_raw_buffer_load_i16( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); #endif @@ -473,11 +537,15 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w else if constexpr(N == 4) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); #else - int32_t tmp = llvm_amdgcn_raw_buffer_load_i32( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); #endif @@ -487,19 +555,24 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE vector_type tmp; - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); tmp.AsType()(Number<1>{}) = llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(int8_t), - 0); + static_cast(coherence)); return tmp.AsType()(Number<0>{}); #else - int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); #endif @@ -509,31 +582,36 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE vector_type tmp; - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); tmp.AsType()(Number<1>{}) = llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 4 * sizeof(int8_t), - 0); + static_cast(coherence)); tmp.AsType()(Number<2>{}) = llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 8 * sizeof(int8_t), - 0); + static_cast(coherence)); tmp.AsType()(Number<3>{}) = llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset + 12 * sizeof(int8_t), - 0); + static_cast(coherence)); return tmp.AsType()(Number<0>{}); #else - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); return bit_cast(tmp); #endif @@ -541,7 +619,9 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w } } -template +template __device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, @@ -549,7 +629,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src { static_assert( (is_same::value && (N == 1 || N == 2)) || - (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4)) || @@ -565,7 +645,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -573,7 +653,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } } else if constexpr(is_same::value) @@ -584,7 +664,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -592,7 +672,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 4) { @@ -600,7 +680,21 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); } } else if constexpr(is_same::value) @@ -611,7 +705,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -619,7 +713,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 4) { @@ -627,7 +721,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 8) { @@ -638,19 +732,19 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 4 * sizeof(half_t), - 0); + static_cast(coherence)); #else llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #endif } } @@ -662,7 +756,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -670,7 +764,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 4) { @@ -678,7 +772,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 8) { @@ -688,13 +782,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 4 * sizeof(bhalf_t), - 0); + static_cast(coherence)); } } else if constexpr(is_same::value) @@ -705,7 +799,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -713,7 +807,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 4) { @@ -721,7 +815,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } } else if constexpr(is_same::value) @@ -732,7 +826,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 2) { @@ -741,13 +835,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #else llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #endif } else if constexpr(N == 4) @@ -757,13 +851,13 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #else llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); #endif } else if constexpr(N == 8) @@ -772,7 +866,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } else if constexpr(N == 16) { @@ -780,7 +874,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, - 0); + static_cast(coherence)); } } } @@ -1012,7 +1106,9 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type::typ // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. -template +template __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, index_t src_thread_element_offset, @@ -1032,13 +1128,30 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; - return amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + if constexpr(is_same::value) + { + auto tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + return bit_cast(tmp); + } + else + { + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + } #else - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - - return src_thread_element_valid ? tmp : vector_t(0); + if constexpr(is_same::value) + { + auto tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + return src_thread_element_valid ? bit_cast(tmp) : vector_t(0); + } + else + { + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + return src_thread_element_valid ? tmp : vector_t(0); + } #endif } @@ -1046,7 +1159,9 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. -template +template __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_element_offset, @@ -1064,7 +1179,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(customized_value); @@ -1074,7 +1189,9 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, // 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. -template +template __device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, T* p_dst_wave, const index_t dst_thread_element_offset, @@ -1093,13 +1210,33 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); + if constexpr(is_same::value) + { + auto tmp = + bit_cast::type::type>(src_thread_data); + amd_buffer_store_impl( + tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); + } + else + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); + } #else if(dst_thread_element_valid) { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + if constexpr(is_same::value) + { + auto tmp = bit_cast::type::type>( + src_thread_data); + amd_buffer_store_impl( + tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } + else + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } } #endif } diff --git a/include/ck/utility/amd_gemm_dpp.hpp b/include/ck/utility/amd_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a28292dade3dcad8ddcd5bd3c0cd119aeb4b4253 --- /dev/null +++ b/include/ck/utility/amd_gemm_dpp.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/inner_product_dpp8.hpp" + +namespace ck { + +namespace dpp8 { + +template +struct dpp_datatypes; + +template <> +struct dpp_datatypes +{ + // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a + // single instruction. + using a_dtype = half_t; + using b_dtype = half_t; + using c_dtype = float; + static constexpr index_t k_per_instr = 2; +}; + +template +struct DppLanegroupGemm +{ + using datatypes_conf = dpp_datatypes; + using ADataType = typename datatypes_conf::a_dtype; + using BDataType = typename datatypes_conf::b_dtype; + using CDataType = typename datatypes_conf::c_dtype; + + __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec) + { + constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread; + + const vector_type a_vector{a_vec}; + const vector_type b_vector{b_vec}; + + static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) { + float c = c_vec.template AsType()(c_idx); + // Next `c_idx` implies that we need to pull data from the next lane. + constexpr index_t source_lane = c_idx; + static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) { + const auto a_k_vec = a_vector.template AsType()[k_chunk]; + const auto b_k_vec = b_vector.template AsType()[k_chunk]; + ck::dpp8:: + inner_product_dpp( + a_k_vec, b_k_vec, c); + }); + c_vec.template AsType()(c_idx) = c; + }); + } +}; + +} // namespace dpp8 + +} // namespace ck diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 1f7df70bcd565dcba4d0aee3a49aca32f835bb45..43baa817d3672c0b2eed6059a88351aeb19a5209 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 4652ce7a747dab404fad36c37439b50f1346381c..741b2975af6c5bf99346b1460018eac6fa33b21b 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -7,6 +7,7 @@ #include "ck/utility/functional2.hpp" #include "ck/utility/math.hpp" +#include #include #include #include @@ -14,29 +15,83 @@ namespace ck { namespace detail { -template -struct get_unsigned_int; +template +struct get_carrier; template <> -struct get_unsigned_int<1> +struct get_carrier<1> { using type = uint8_t; }; template <> -struct get_unsigned_int<2> +struct get_carrier<2> { using type = uint16_t; }; template <> -struct get_unsigned_int<4> +struct get_carrier<3> +{ + using type = class carrier + { + using value_type = uint32_t; + + std::array bytes; + static_assert(sizeof(bytes) <= sizeof(value_type)); + + // replacement of host std::copy_n() + template + __device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to) + { + if(0 < size) + { + *to = *from; + ++to; + for(Size count = 1; count < size; ++count) + { + *to = *++from; + ++to; + } + } + + return to; + } + + // method to trigger template substitution failure + __device__ carrier(const carrier& other) noexcept + { + copy_n(other.bytes.begin(), bytes.size(), bytes.begin()); + } + + public: + __device__ carrier& operator=(value_type value) noexcept + { + copy_n(reinterpret_cast(&value), bytes.size(), bytes.begin()); + + return *this; + } + + __device__ operator value_type() const noexcept + { + std::byte result[sizeof(value_type)]; + + copy_n(bytes.begin(), bytes.size(), result); + + return *reinterpret_cast(result); + } + }; +}; +static_assert(sizeof(get_carrier<3>::type) == 3); + +template <> +struct get_carrier<4> { using type = uint32_t; }; -template -using get_unsigned_int_t = typename get_unsigned_int::type; +template +using get_carrier_t = typename get_carrier::type; } // namespace detail @@ -61,7 +116,7 @@ __device__ auto amd_wave_read_first_lane(const Object& obj) constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize) { - using Sgpr = detail::get_unsigned_int_t; + using Sgpr = detail::get_carrier_t; *reinterpret_cast(to_obj + offset) = amd_wave_read_first_lane(*reinterpret_cast(from_obj + offset)); @@ -69,9 +124,9 @@ __device__ auto amd_wave_read_first_lane(const Object& obj) if constexpr(0 < RemainedSize) { - using Carrier = detail::get_unsigned_int_t; + using Carrier = detail::get_carrier_t; - *reinterpret_cast(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane( + *reinterpret_cast(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane( *reinterpret_cast(from_obj + CompleteSgprCopyBoundary)); } diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index bf091425485d75a244399f4781e1d44c4b38dd2c..dd7f0b770a11470037b1d33164d830cb9819e8d1 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index a742496fc1f2fcf2ef529bbf079c481bc8d99f03..ea7755036fc64b0af6aa340e062e1277c485e785 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_XDLOPS_HPP #define CK_AMD_XDLOPS_HPP @@ -344,7 +344,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> template __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) { -#if defined(__gfx90a__) || defined(__gfx940__) +#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -354,5 +354,68 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> #endif } }; + +template +struct intrin_mfma_f32_32x32x16f8f8; + +template <> +struct intrin_mfma_f32_32x32x16f8f8<32, 32> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32f8f8; + +template <> +struct intrin_mfma_f32_16x16x32f8f8<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; } // namespace ck #endif diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 370a457fe9d97a621e6f91e78f14733398ebf756..f63ce5e5a07a796888cb60ae8da0c855df75e7ff 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP diff --git a/include/ck/utility/array_multi_index.hpp b/include/ck/utility/array_multi_index.hpp index 9b8d5b95e9f6241b574104313b5b8fa2984ab18d..c0c1ea65fc77dd625aaaf6456a6628e09df27b51 100644 --- a/include/ck/utility/array_multi_index.hpp +++ b/include/ck/utility/array_multi_index.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_MULTI_INDEX_HPP #define CK_ARRAY_MULTI_INDEX_HPP diff --git a/include/ck/utility/c_style_pointer_cast.hpp b/include/ck/utility/c_style_pointer_cast.hpp index 6e8b0081587afcb2db159a05ecd5fb940def68ff..610e393a77216500448c3682b5db9b9860a077da 100644 --- a/include/ck/utility/c_style_pointer_cast.hpp +++ b/include/ck/utility/c_style_pointer_cast.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_C_STYLE_POINTER_CAST_HPP #define CK_C_STYLE_POINTER_CAST_HPP diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 8da87c8766ea6c53e2eb60b98a53d4b30d525fcd..f95660a8a47d2c93a65a497a22a74c4c0b6eaea4 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -24,6 +24,7 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/tuple_helper.hpp" #include "ck/utility/type.hpp" +#include "ck/utility/type_convert.hpp" #include "ck/utility/magic_division.hpp" #include "ck/utility/c_style_pointer_cast.hpp" #include "ck/utility/is_known_at_compile_time.hpp" diff --git a/include/ck/utility/container_element_picker.hpp b/include/ck/utility/container_element_picker.hpp index abc5185e04a5079edb95fe15bfcd6788256d384a..838147e420cc7c95b63273b32ae3b2be2429c017 100644 --- a/include/ck/utility/container_element_picker.hpp +++ b/include/ck/utility/container_element_picker.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_ELEMENT_PICKER_HPP #define CK_CONTAINER_ELEMENT_PICKER_HPP diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index c8b02bc5acaf8542bf028fff39c42061db3b3296..9c7b954565d386a8fdecd21052b102e750ab7102 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index d43af8a2e393574bf2b56ca311387717b5b1fe43..c240afa2b867951059d59dd3e7e267edd816382e 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -12,6 +12,7 @@ using half_t = _Float16; #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 using int4_t = _BitInt(4); #endif +using f8_t = uint8_t; // vector_type template @@ -142,6 +143,13 @@ struct scalar_type }; #endif +template <> +struct scalar_type +{ + using type = f8_t; + static constexpr index_t vector_size = 1; +}; + // template struct vector_type @@ -944,151 +952,13 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; -// Convert X to Y -template -__host__ __device__ constexpr Y type_convert(X x) -{ - static_assert(!std::is_reference_v && !std::is_reference_v); - - return static_cast(x); -} - -// convert bfp16 to fp32 -template <> -inline __host__ __device__ constexpr float type_convert(bhalf_t x) -{ - union - { - uint32_t int32; - float fp32; - } u = {uint32_t(x) << 16}; - - return u.fp32; -} - -// convert fp32 to bfp16 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(float x) -{ - union - { - float fp32; - uint32_t int32; - } u = {x}; - - return uint16_t(u.int32 >> 16); -} - -// convert bfp16 to fp16 via fp32 -template <> -inline __host__ __device__ constexpr half_t type_convert(bhalf_t x) -{ - float x_fp32 = type_convert(x); - - return static_cast(x_fp32); -} - -// convert fp16 to bfp16 via fp32 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(half_t x) -{ - float x_fp32 = static_cast(x); - - return type_convert(x_fp32); -} - -// convert bfp16 to int32 via fp32 -template <> -inline __host__ __device__ constexpr int32_t type_convert(bhalf_t x) -{ - float x_fp32 = type_convert(x); - - return static_cast(x_fp32); -} - -// convert int32 to bfp16 via fp32 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(int32_t x) -{ - float x_fp32 = static_cast(x); - - return type_convert(x_fp32); -} - -// convert bfp16 to int8 via fp32 -template <> -inline __host__ __device__ constexpr int8_t type_convert(bhalf_t x) -{ - float x_fp32 = type_convert(x); - - return static_cast(x_fp32); -} - -// convert int8 to bfp16 via fp32 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(int8_t x) -{ - float x_fp32 = static_cast(x); - - return type_convert(x_fp32); -} - -// Declare a template function for bf16 conversion using RTN -template -__host__ __device__ constexpr Y bf16_convert_rtn(X x); - -// Convert fp32 to bf16 with RTN if higher precision is needed -template <> -inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) -{ - union - { - float fp32; - uint32_t int32; - } u = {x}; - - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - bool flag0 = ~u.int32 & 0x7f800000; - - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bfloat16's mantissa bits are all 0. - bool flag1 = !flag0 && (u.int32 & 0xffff); - - u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even - u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN - - return uint16_t(u.int32 >> 16); -} - -// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed -template <> -inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(half_t x) -{ - float x_fp32 = static_cast(x); - - return bf16_convert_rtn(x_fp32); -} +// f8 +using f8x2_t = typename vector_type::type; +using f8x4_t = typename vector_type::type; +using f8x8_t = typename vector_type::type; +using f8x16_t = typename vector_type::type; +using f8x32_t = typename vector_type::type; +using f8x64_t = typename vector_type::type; template struct NumericLimits @@ -1136,4 +1006,21 @@ struct NumericLimits }; #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x77; // 0b01110111 + static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + + __host__ __device__ static constexpr f8_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast(binary_qnan); } +}; + } // namespace ck diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 593bbb711672f7dc6e3db22d877bf249b5dff56d..80346f0d9f6f9e5a6a28dcab4e8a7666639394ca 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index c6f0d299ef3c35c63ad97424888545edd2874914..02d61f34ed5192480d09f4432eec30e65e7264b9 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,7 +19,8 @@ namespace ck { template + bool InvalidElementUseNumericalZeroValue, + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence> struct DynamicBuffer { using type = T; @@ -77,13 +78,16 @@ struct DynamicBuffer if constexpr(InvalidElementUseNumericalZeroValue) { - return amd_buffer_load_invalid_element_return_zero, t_per_x>( + return amd_buffer_load_invalid_element_return_zero, + t_per_x, + coherence>( p_data_, i, is_valid_element, element_space_size_); } else { return amd_buffer_load_invalid_element_return_customized_value, - t_per_x>( + t_per_x, + coherence>( p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); } } @@ -173,7 +177,7 @@ struct DynamicBuffer { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - amd_buffer_store, t_per_x>( + amd_buffer_store, t_per_x, coherence>( x, p_data_, i, is_valid_element, element_space_size_); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && @@ -376,14 +380,19 @@ struct DynamicBuffer __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } }; -template +template __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) { - return DynamicBuffer{p, element_space_size}; + return DynamicBuffer{ + p, element_space_size}; } template < AddressSpaceEnum BufferAddressSpace, + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, typename T, typename ElementSpaceSize, typename X, @@ -391,7 +400,7 @@ template < __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) { - return DynamicBuffer{ + return DynamicBuffer{ p, element_space_size, invalid_element_value}; } diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index 297434b0dddd5f1a680176c3bd099bedfda8dff4..c0a3c99f1fdafea9f151fe9fc319c2f7aaa0ffda 100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bb13f98154ef4ec072529c26ea993f8e1a444abf --- /dev/null +++ b/include/ck/utility/f8_utils.hpp @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +namespace ck { + +// fp8 rounding modes +// use standard for rounding to nearest, the faster one +// use stochastic for stochastic rounding, helps to avoid error accumulation +enum class f8_rounding_mode +{ + standard, + stochastic +}; + +} // namespace ck + +namespace ck::utils { + +namespace { + +template +__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) +{ + // check data type + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + + // fp8 exponent/mantissa layout + constexpr int f8_exp = 4; + constexpr int f8_mant = 3; + + // resulting type exponent/mantissa layout + constexpr int type_exp = is_half ? 5 : 8; + constexpr int type_mant = is_half ? 10 : 23; + + int exponent; + uint32_t head, mantissa, sign; + // nan code is same for float and half + constexpr uint8_t nan_code = 0x80; + constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; + + // convert to bitwise + typedef typename std::conditional::value, uint16_t, uint32_t>::type + T_bitwise; + T_bitwise x_bitwise = *(reinterpret_cast(&x)); + + // unpack the input, depends on datatype + if constexpr(is_float) + { + head = x_bitwise & 0xFF800000; + mantissa = x_bitwise & 0x7FFFFF; + exponent = (head >> type_mant) & 0xFF; + sign = head >> (type_exp + type_mant); + } + else if constexpr(is_half) + { + head = x_bitwise & 0xFC00; + mantissa = x_bitwise & 0x3FF; + exponent = (head >> type_mant) & 0x1F; + sign = head >> (type_exp + type_mant); + } + + uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant); + uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1; + constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2); + constexpr int exp_low_cutoff = + (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + if constexpr(negative_zero_nan) + { + if((x_bitwise & nan_mask) == nan_mask) + return nan_code; + } + else + { + if((x_bitwise & nan_mask) == nan_mask) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + + // check if x is 0.0 + if(x_bitwise == 0) + return 0; + + exponent -= exp_low_cutoff - 1; + if(exponent <= 0) + drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; + mantissa += 1 << type_mant; + // apply random number if needed + mantissa += (stoch ? rng : mantissa) & drop_mask; + if(mantissa >= (2 << type_mant)) + { + mantissa >>= 1; + exponent++; + } + mantissa >>= (type_mant - f8_mant); + + // check negative exponent + if(exponent <= 0) + { + if(x_bitwise == 0) + return 0; + else + { + // subnormal range; represented by a subnormal float8 (exponent 0) + // and involves loss of accuracy + mantissa >>= 1 - exponent; + exponent = 0; + } + } + // above range: quantize to maximum possible float of the same sign + else if(exponent > max_exp) + { + if(clip) + { + mantissa = (1 << f8_mant) - 1; + exponent = max_exp; + } + else + { + return signed_inf; + } + } + + // check if x is 0.0 or -0.0 + if(exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); + mantissa &= (1 << f8_mant) - 1; + return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; +} + +template +__host__ __device__ T run_cast_from_f8(f8_t x) +{ + // check data type + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + + // fp8 exponent/mantissa layout + constexpr int f8_exp = 4; + constexpr int f8_mant = 3; + + // resulting type exponent/mantissa layout + constexpr int type_exp = is_half ? 5 : 8; + constexpr int type_mant = is_half ? 10 : 23; + + // prepare the codes + constexpr uint8_t nan_code = 0x80; + T fInf, fNegInf, fNaN, fNeg0; + if constexpr(is_half) + { + constexpr uint16_t ihInf = 0x7C00; + constexpr uint16_t ihNegInf = 0xFC00; + constexpr uint16_t ihNaN = 0x7C01; + constexpr uint16_t ihNeg0 = 0x8000; + fInf = *(reinterpret_cast(&ihInf)); + fNegInf = *(reinterpret_cast(&ihNegInf)); + fNaN = *(reinterpret_cast(&ihNaN)); + fNeg0 = *(reinterpret_cast(&ihNeg0)); + } + else if constexpr(is_float) + { + constexpr uint32_t ifInf = 0x7F800000; + constexpr uint32_t ifNegInf = 0xFF800000; + constexpr uint32_t ifNaN = 0x7F800001; + constexpr uint32_t ifNeg0 = 0x80000000; + fInf = *(reinterpret_cast(&ifInf)); + fNegInf = *(reinterpret_cast(&ifNegInf)); + fNaN = *(reinterpret_cast(&ifNaN)); + fNeg0 = *(reinterpret_cast(&ifNeg0)); + } + + // unpack the input + uint32_t sign = x >> (f8_exp + f8_mant); + uint32_t mantissa = x & ((1 << f8_mant) - 1); + int exponent = (x & 0x7F) >> f8_mant; + + constexpr int exp_low_cutoff = + (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + typename std::conditional::value, uint16_t, uint32_t>::type retval; + + if constexpr(negative_zero_nan) + { + if(x == nan_code) + return fNaN; + } + else + { + if(x == nan_code) + return fNeg0; + if(exponent == ((1 << f8_exp) - 1)) + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); + mantissa <<= sh; + mantissa &= ((1 << f8_mant) - 1); + exponent += 1 - sh; + } + exponent += exp_low_cutoff - 1; + mantissa <<= type_mant - f8_mant; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << type_mant; + mantissa >>= 1 - exponent; + exponent = 0; + } + + retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; + return *(reinterpret_cast(&retval)); +} + +} // namespace + +template +__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted to f8."); + + return run_cast_to_f8(x, rng); +} + +template +__host__ __device__ T cast_from_f8(f8_t x) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported."); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); + + return run_cast_from_f8(x); +} + +} // namespace ck::utils diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index 08e730782f386cf5788c64bc04a008dc6cb37b28..91797d24092e3e32ad4a6bd40958952b124d9978 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional2.hpp b/include/ck/utility/functional2.hpp index 6f125ca4c944777d02ae1083334bec4602ad68c4..99c65f4eb85b67231557b1916fae10c6568e676d 100644 --- a/include/ck/utility/functional2.hpp +++ b/include/ck/utility/functional2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional3.hpp b/include/ck/utility/functional3.hpp index 06b67ef7e3fdec25a1ab2e18e23c904342233ebe..97605a7adeb8ae64e4e5a32debe9386295923b6c 100644 --- a/include/ck/utility/functional3.hpp +++ b/include/ck/utility/functional3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional4.hpp b/include/ck/utility/functional4.hpp index 6eeaf15c9b7ac283f7fcac96c195d28f179b6065..b5f3df8d7c517dfaf01320e41721da174883c2d9 100644 --- a/include/ck/utility/functional4.hpp +++ b/include/ck/utility/functional4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 6a1ca966521710ecdb4ceaad9a83457031f4c508..98f40a4363aa2ddf2908fa497151f756b77d6f94 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index 44ff438155d2a9c9ba9b0925d4e9fe95c1fa6bce..77564c6130baf45bfd331e69fc437fb7c7c96d18 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/get_shift.hpp b/include/ck/utility/get_shift.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a93081cfd01ce41cf305e4d69e127053de788cc --- /dev/null +++ b/include/ck/utility/get_shift.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +template +static constexpr __device__ index_t get_shift() +{ + return (get_shift() + 1); +}; + +template <> +constexpr __device__ index_t get_shift<1>() +{ + return (0); +} + +} // namespace ck diff --git a/include/ck/utility/ignore.hpp b/include/ck/utility/ignore.hpp index ac33cbf9a508f5e6402d17d6285ae34a8196f6eb..f70a182fd4e5c6cc623acb2c65b5551b2a1acd14 100644 --- a/include/ck/utility/ignore.hpp +++ b/include/ck/utility/ignore.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index b65640bfffed31e319549b7152f0c2a93bfaded9..b58b2b33191e23db2bda3f9d15f647f6d4f85f58 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" +#include "type_convert.hpp" namespace ck { @@ -12,13 +13,13 @@ __device__ void inner_product(const TA& a, const TB& b, TC& c); template <> __device__ void inner_product(const float& a, const float& b, float& c) { -#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) +#if CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) asm volatile("\n \ v_mac_f32 %0, %1, %2 \n \ " : "=v"(c) : "v"(a), "v"(b), "0"(c)); -#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) +#elif CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) asm volatile("\n \ v_fmac_f32 %0, %1, %2 \n \ " @@ -75,22 +76,26 @@ template <> __device__ void inner_product(const half2_t& a, const half2_t& b, float& c) { #if defined(CK_USE_AMD_V_DOT2_F32_F16) -#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM +#if CK_USE_AMD_V_DOT_INLINE_ASM + // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47 + // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf + // ) s_nop with parameter 2 is equal to 3 x s_nop asm volatile("\n \ v_dot2_f32_f16 %0, %1, %2, %0\n \ + s_nop 2 \n \ " : "=v"(c) : "v"(a), "v"(b), "0"(c)); #else - c = __builtin_amdgcn_sdot2(a, b, c, false); + c = __builtin_amdgcn_fdot2(a, b, c, false); #endif #else const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 2, 1>{}([&](auto i) { - c += type_convert(a_vector.AsType()[i]) * - type_convert(b_vector.AsType()[i]); + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); #endif } @@ -162,9 +167,13 @@ __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) { #if defined(CK_USE_AMD_V_DOT4_I32_I8) -#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM +#if CK_USE_AMD_V_DOT_INLINE_ASM + // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47 + // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf + // ) s_nop with parameter 2 is equal to 3 x s_nop asm volatile("\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \ + s_nop 2 \n \ " : "=v"(c) : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); diff --git a/include/ck/utility/inner_product_dpp8.hpp b/include/ck/utility/inner_product_dpp8.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f079e2ca6490676d6c0e83e9e74dba205a8d6583 --- /dev/null +++ b/include/ck/utility/inner_product_dpp8.hpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "amd_gemm_dpp.hpp" +#include "data_type.hpp" +#include "type_convert.hpp" + +namespace ck { + +namespace dpp8 { + +/// Number of lanes that can share data using DPP8 modifiers. +constexpr index_t lane_group_size = 8; + +template +__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); + +// clang-format off +template <> +__device__ void inline_v_dot2c_dpp8_instr<0>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[0, 0, 0, 0, 0, 0, 0, 0]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<1>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[1, 1, 1, 1, 1, 1, 1, 1]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<2>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[2, 2, 2, 2, 2, 2, 2, 2]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<3>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[3, 3, 3, 3, 3, 3, 3, 3]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<4>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[4, 4, 4, 4, 4, 4, 4, 4]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<5>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[5, 5, 5, 5, 5, 5, 5, 5]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<6>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[6, 6, 6, 6, 6, 6, 6, 6]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<7>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[7, 7, 7, 7, 7, 7, 7, 7]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +// clang-format on + +/** + * Dot product of two vectors using `v_dot` instruction with DPP8 submitted as inline assembly. + */ +template +__device__ void inline_v_dot2c_dpp8(const half2_t& a, const half2_t& b, float& c) +{ + static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size, + "DPP8 src broadcast lane out of range <0, 7>."); + if constexpr(ShareA) + { + inline_v_dot2c_dpp8_instr(a, b, c); + } + else + { + inline_v_dot2c_dpp8_instr(b, a, c); + } +} + +/** + * DPP8 instrinsics expects to get an integer mask, hardcoding integers for specific broadcast + * patters. + */ +constexpr std::array IntrinsicMaskDpp8 = { + 0, // 0, 0, 0, 0, 0, 0, 0, 0 + 2396745, // 1, 1, 1, 1, 1, 1, 1, 1 + 4793490, // 2, 2, 2, 2, 2, 2, 2, 2 + 7190235, // 3, 3, 3, 3, 3, 3, 3, 3 + 9586980, // 4, 4, 4, 4, 4, 4, 4, 4 + 11983725, // 5, 5, 5, 5, 5, 5, 5, 5 + 14380470, // 6, 6, 6, 6, 6, 6, 6, 6 + 16777215, // 7, 7, 7, 7, 7, 7, 7, 7 +}; + +/** + * Returns DPP8 sel modifier as an integer required for the intrinsic instruction. + */ +template +constexpr int get_dpp_sel_mask_broadcast() +{ + static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size, + "DPP8 src broadcast lane out of range <0, 7>."); + return IntrinsicMaskDpp8[SrcLaneIdx]; +} + +template +__device__ void intrinsic_fdot2_impl(const half2_t& a, const half2_t& b, float& c) +{ + constexpr int sel_mask = get_dpp_sel_mask_broadcast(); + const half2_t val_from_other_lane = + bit_cast(__builtin_amdgcn_mov_dpp8(bit_cast(a), sel_mask)); + c = __builtin_amdgcn_fdot2(val_from_other_lane, b, c, false); +} + +/** + * Dot product of two vectors using `v_dot` instruction with DPP8 submitted using intrinsics. + */ +template +__device__ void intrinsic_fdot2(const half2_t& a, const half2_t& b, float& c) +{ + if constexpr(ShareA) + { + intrinsic_fdot2_impl(a, b, c); + } + else + { + intrinsic_fdot2_impl(b, a, c); + } +} + +/** + * Dot product of two input vectors `a`, `b` using `v_dot` instructions with DPP modifier. + * + * DPP modifier allows us to share one of the vectors between lanes in a lane group. + * When `ShareA` is set, instruction uses vector `a` from lane `SrcLaneIdx` from the same + * lane group (8 lanes per lane group in DPP8). When `ShareA` is not set, vector `b` is shared. + * Note that all the threads in a lane group uses the same vector - broadcast pattern. + * + * `SrcLaneIdx` must be in range from 0 to 7. + */ +template +__device__ void inner_product_dpp(const TA& a, const TB& b, TC& c) +{ +#if CK_USE_AMD_V_DOT_DPP8_INLINE_ASM + inline_v_dot2c_dpp8(a, b, c); +#else + intrinsic_fdot2(a, b, c); +#endif +} + +} // namespace dpp8 + +} // namespace ck diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 9aab4e24214a884bdfb391e0f9040bb3af2630ec..376070eb3d8ac326603b71e52e76949c168f4219 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/is_known_at_compile_time.hpp b/include/ck/utility/is_known_at_compile_time.hpp index 8198154422e5d5228b79c6539bddcd5070d1e25c..2cafc3e6f2fafd247e378446a15a7e16c019c914 100644 --- a/include/ck/utility/is_known_at_compile_time.hpp +++ b/include/ck/utility/is_known_at_compile_time.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b2eb0ddb9375a73299a46e43a2632a338f09b7ee --- /dev/null +++ b/include/ck/utility/loop_scheduler.hpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +enum struct LoopScheduler +{ + Default, + Interwave, +}; + +constexpr LoopScheduler make_default_loop_scheduler() +{ +#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING + return LoopScheduler::Interwave; +#else + return LoopScheduler::Default; +#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING +} + +} // namespace ck diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index a5e8e9216519074d6542d379512c5b8d13ec21c9..1d1f914c6653c1fadd7c722aa4ee4ea5cfd851df 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -157,4 +157,76 @@ struct MagicDivision } }; +struct MDiv +{ + // 1 dword -> 3 dword storage + uint32_t divisor; + uint32_t multiplier; + uint32_t shift; // TODO: 8 bit is enough + + // prefer construct on host + __host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_) + { + auto tmp = MagicDivision::CalculateMagicNumbers(divisor_); + + multiplier = tmp[Number<0>{}]; + shift = tmp[Number<1>{}]; + } + + __host__ __device__ MDiv() : divisor(0), multiplier(0), shift(0) {} + + __host__ __device__ void update(uint32_t divisor_) + { + divisor = divisor_; + auto tmp = MagicDivision::CalculateMagicNumbers(divisor_); + + multiplier = tmp[Number<0>{}]; + shift = tmp[Number<1>{}]; + } + + __host__ __device__ uint32_t div(uint32_t dividend_) const + { + return MagicDivision::DoMagicDivision(dividend_, multiplier, shift); + } + + __host__ __device__ void + divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const + { + quotient_ = div(dividend_); + remainder_ = dividend_ - (quotient_ * divisor); + } + + __host__ __device__ uint32_t get() const { return divisor; } +}; + +struct MDiv2 +{ + // 1 dword -> 2 dword storage, divisor need compute from runtime + uint32_t multiplier; + uint32_t shift; // TODO: 8 bit is enough + + // prefer construct on host + __host__ __device__ MDiv2(uint32_t divisor_) + { + auto tmp = MagicDivision::CalculateMagicNumbers(divisor_); + + multiplier = tmp[Number<0>{}]; + shift = tmp[Number<1>{}]; + } + + __host__ __device__ MDiv2() : multiplier(0), shift(0) {} + + __host__ __device__ uint32_t div(uint32_t dividend_) const + { + return MagicDivision::DoMagicDivision(dividend_, multiplier, shift); + } + + __host__ __device__ void + divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const + { + quotient_ = div(dividend_); + remainder_ = dividend_ - (quotient_ * divisor_); + } +}; + } // namespace ck diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index 72071992f65d934e91a6976c2cce3084b4a2d4ca..c5e967c8f4940039b5dd4f59e1788e764518d552 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -240,5 +240,21 @@ struct less __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; } }; +template +__host__ __device__ constexpr auto next_power_of_two() +{ + // TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail + constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1)); + return Y; +} + +template +__host__ __device__ constexpr auto next_power_of_two(Number x) +{ + // TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail + constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1)); + return Number{}; +} + } // namespace math } // namespace ck diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index a3732b2fe0f0d813a281cefca403f317a6daef81..1cac2cc0c7172b4b1915845d6c7f3fa853653683 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/multi_index.hpp b/include/ck/utility/multi_index.hpp index 1d544c0906cae3c61f4d7ce27e74e7636c3919b5..9f7ba8bff63b3f6e625b98f66accd875ee118594 100644 --- a/include/ck/utility/multi_index.hpp +++ b/include/ck/utility/multi_index.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/number.hpp b/include/ck/utility/number.hpp index f3ca6b61dc6ac08330a6cf1633148bb9bad8cd81..d29afd31a7ef03ef3fdbf11f8e9a88132e07fe82 100644 --- a/include/ck/utility/number.hpp +++ b/include/ck/utility/number.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_NUMBER_HPP #define CK_NUMBER_HPP diff --git a/include/ck/utility/random_gen.hpp b/include/ck/utility/random_gen.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b7edf26507c62d5365ecab3fe1660c39a2cd672f --- /dev/null +++ b/include/ck/utility/random_gen.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +// Pseudo random number generator +// version for fp32 +template {}, bool> = false> +__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) +{ + uint32_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits ^= x >> 16; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is very + // large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; +} + +// version for fp16 +template {}, bool> = false> +__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) +{ + uint16_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is very + // large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; +} + +// return 0 if data is not fp16 or fp32 +template {} || std::is_same{}), bool> = false> +__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t) +{ + std::ignore = id; + std::ignore = val; + std::ignore = seed; + + return 0; +} + +} // namespace ck diff --git a/include/ck/utility/reduction_common.hpp b/include/ck/utility/reduction_common.hpp index aceef7b296da7466c484aad8033fe1f949662fca..75fdd85825a71184e6f5afa4688a43c6c7a5518a 100644 --- a/include/ck/utility/reduction_common.hpp +++ b/include/ck/utility/reduction_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -25,16 +25,4 @@ struct float_equal_zero }; }; -template -static constexpr __device__ index_t get_shift() -{ - return (get_shift() + 1); -}; - -template <> -constexpr __device__ index_t get_shift<1>() -{ - return (0); -} - } // namespace ck diff --git a/include/ck/utility/reduction_enums.hpp b/include/ck/utility/reduction_enums.hpp index 67856331059cef628a62e4b2223f262b839bbf02..23b7149f8eb022a00d1aed81b4b69c315ca37ba0 100644 --- a/include/ck/utility/reduction_enums.hpp +++ b/include/ck/utility/reduction_enums.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/reduction_functions_accumulate.hpp b/include/ck/utility/reduction_functions_accumulate.hpp index 724e5599d6c878ba08bf14429c40c0958ab9f2a2..b9765ff0d2cb57bb2a1a0c4c7684a3cd2e0fa008 100644 --- a/include/ck/utility/reduction_functions_accumulate.hpp +++ b/include/ck/utility/reduction_functions_accumulate.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/reduction_operator.hpp b/include/ck/utility/reduction_operator.hpp index b4e770a64efdda0b42e7996e2858d0f162231fee..5480a98409e3c52ba5d24cbcc930213658ea374b 100644 --- a/include/ck/utility/reduction_operator.hpp +++ b/include/ck/utility/reduction_operator.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/ck.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type.hpp" +#include "ck/utility/type_convert.hpp" namespace ck { @@ -115,7 +116,15 @@ struct Max template __host__ __device__ static constexpr T GetIdentityValue() { - return NumericLimits::Lowest(); + if constexpr(is_same_v) + { + float val = NumericLimits::Lowest(); + return type_convert(val); + } + else + { + return NumericLimits::Lowest(); + } }; __host__ __device__ static constexpr bool @@ -137,6 +146,15 @@ struct Max a = b; } + __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const + { + float a_ = type_convert(a); + float b_ = type_convert(b); + + if(a_ < b_) + a = b; + } + template __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const { @@ -151,6 +169,18 @@ struct Max changed = true; } } + + __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const + { + float a_ = type_convert(a); + float b_ = type_convert(b); + + if(a_ < b_) + { + a = b; + changed = true; + } + } }; struct Min @@ -158,6 +188,15 @@ struct Min template __host__ __device__ static constexpr T GetIdentityValue() { + if constexpr(is_same_v) + { + float val = NumericLimits::Max(); + return type_convert(val); + } + else + { + return NumericLimits::Max(); + } return NumericLimits::Max(); }; @@ -180,6 +219,15 @@ struct Min a = b; } + __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const + { + float a_ = type_convert(a); + float b_ = type_convert(b); + + if(a_ > b_) + a = b; + } + template __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const { @@ -194,6 +242,18 @@ struct Min changed = true; } } + + __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const + { + float a_ = type_convert(a); + float b_ = type_convert(b); + + if(a_ > b_) + { + a = b; + changed = true; + } + } }; struct AMax diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 97b597221c2850d4b28b644266978fc56b295913..d6bfb2eba1cd6a6de2e334261379504d8448b3c7 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index db25c27e70c3653f94524367b8a6bce79480113e..8c493a28221eae4e095f051264f9323429984531 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/span.hpp b/include/ck/utility/span.hpp index 1e501214547cfbcec25921d6526e77563504835a..5e7567a847f3d18e424437005d876a409cd0d99d 100644 --- a/include/ck/utility/span.hpp +++ b/include/ck/utility/span.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index dd25c96203288cfc9aa0959e9a7b1d83ef4abfc2..835f5657307a8e171b116c0bec088bb84b1b7b38 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp index 3438776f413cf010664cc8aa4a18c09bf161fae7..a2d70045a4b9b221651d958472aa6a2e721ccccc 100644 --- a/include/ck/utility/statically_indexed_array.hpp +++ b/include/ck/utility/statically_indexed_array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_STATICALLY_INDEXED_ARRAY_HPP #define CK_STATICALLY_INDEXED_ARRAY_HPP diff --git a/include/ck/utility/statically_indexed_array_multi_index.hpp b/include/ck/utility/statically_indexed_array_multi_index.hpp index 21b2941b21401ac561932a45ec304536365f2927..4a8b96ae8a6b81400f3567957f07010c26c0bf8d 100644 --- a/include/ck/utility/statically_indexed_array_multi_index.hpp +++ b/include/ck/utility/statically_indexed_array_multi_index.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index 0e247ed0f8df15b890cdd44ea9f1682d5823a89f..775e7ac3a392f5edd16054e63314bc80faca5d6e 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/thread_group.hpp b/include/ck/utility/thread_group.hpp index d469dec899a556c5b8062efa775b7dc694500fe8..1cd6b2f3ce26a9a89e98600f7377aac0d93af344 100644 --- a/include/ck/utility/thread_group.hpp +++ b/include/ck/utility/thread_group.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/transpose_vectors.hpp b/include/ck/utility/transpose_vectors.hpp index 2b0075d6005e3c5cf7d6772fda8132a6404878ed..6faf5c133ba2df0ed938aab915a272e3a7a590f3 100644 --- a/include/ck/utility/transpose_vectors.hpp +++ b/include/ck/utility/transpose_vectors.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index d8664be550b79bc4de2edd224bc63350a0a1bde7..b616b3123f3b9ce3c92763cd36ecd39fc7d3e553 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 6f5b142a5e7a765807bd9d7f556f7b8afc512d37..e39ae1c23d421e380b81b4bc7b487025cb6937be 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/type.hpp b/include/ck/utility/type.hpp index 90b9df2950b7d979f8cef386d155c72f6d7a39a5..9609afba43a78c05332d459b3a7cf756931a6e60 100644 --- a/include/ck/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp new file mode 100644 index 0000000000000000000000000000000000000000..65d89403773d7c6cebcd9548a09b0d0404c8176b --- /dev/null +++ b/include/ck/utility/type_convert.hpp @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/f8_utils.hpp" +#include "ck/utility/random_gen.hpp" + +namespace ck { + +// Convert X to Y +template +__host__ __device__ constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + return static_cast(x); +} + +// convert bfp16 to fp32 +template <> +inline __host__ __device__ constexpr float type_convert(bhalf_t x) +{ + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(x) << 16}; + + return u.fp32; +} + +// convert fp32 to bfp16 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + return uint16_t(u.int32 >> 16); +} + +// convert bfp16 to fp16 via fp32 +template <> +inline __host__ __device__ constexpr half_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert fp16 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(half_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert bfp16 to int8 via fp32 +template <> +inline __host__ __device__ constexpr int8_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert int8 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(int8_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert fp32 to fp8 +template <> +inline __host__ __device__ f8_t type_convert(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils::cast_to_f8( + x, rng); +} + +// convert fp8 to fp32 +template <> +inline __host__ __device__ float type_convert(f8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} + +// convert fp16 to fp8 +template <> +inline __host__ __device__ f8_t type_convert(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils::cast_to_f8( + x, rng); +} + +// convert fp8 to fp16 +template <> +inline __host__ __device__ half_t type_convert(f8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} + +// Declare a template function for bf16 conversion using RTN +template +__host__ __device__ constexpr Y bf16_convert_rtn(X x); + +// Convert fp32 to bf16 with RTN if higher precision is needed +template <> +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + bool flag0 = ~u.int32 & 0x7f800000; + + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bfloat16's mantissa bits are all 0. + bool flag1 = !flag0 && (u.int32 & 0xffff); + + u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even + u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN + + return uint16_t(u.int32 >> 16); +} + +// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed +template <> +inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(half_t x) +{ + float x_fp32 = static_cast(x); + + return bf16_convert_rtn(x_fp32); +} + +// Declare a template function for fp8 conversion using SR +template +__host__ __device__ constexpr Y f8_convert_sr(X x); + +// convert fp32 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_t f8_convert_sr(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils::cast_to_f8( + x, rng); +} + +// convert fp16 to fp8 with stochastic rounding +template <> +inline __host__ __device__ f8_t f8_convert_sr(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils::cast_to_f8( + x, rng); +} + +} // namespace ck diff --git a/include/ck/utility/workgroup_barrier.hpp b/include/ck/utility/workgroup_barrier.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec9151fd1bea830b5abba6881f04be6b68f6bbb5 --- /dev/null +++ b/include/ck/utility/workgroup_barrier.hpp @@ -0,0 +1,73 @@ +#pragma once +#include +#include + +namespace ck { +struct workgroup_barrier +{ + __device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {} + + __device__ uint32_t ld(uint32_t offset) + { +#if 0 + float d = llvm_amdgcn_raw_buffer_load_fp32( + amdgcn_make_buffer_resource(base_ptr), + 0, + offset, + AMDGCN_BUFFER_GLC); + union cvt { + float f32; + uint32_t u32; + }; + cvt x; + x.f32 = d; + return x.u32; +#endif + return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED); + } + + __device__ void wait_eq(uint32_t offset, uint32_t value) + { + if(threadIdx.x == 0) + { + while(ld(offset) != value) {} + } + __syncthreads(); + } + + __device__ void wait_lt(uint32_t offset, uint32_t value) + { + if(threadIdx.x == 0) + { + while(ld(offset) < value) {} + } + __syncthreads(); + } + + __device__ void wait_set(uint32_t offset, uint32_t compare, uint32_t value) + { + if(threadIdx.x == 0) + { + while(atomicCAS(base_ptr + offset, compare, value) != compare) {} + } + __syncthreads(); + } + + // enter critical zoon, assume buffer is zero when launch kernel + __device__ void aquire(uint32_t offset) { wait_set(offset, 0, 1); } + + // exit critical zoon, assume buffer is zero when launch kernel + __device__ void release(uint32_t offset) { wait_set(offset, 1, 0); } + + __device__ void inc(uint32_t offset) + { + __syncthreads(); + if(threadIdx.x == 0) + { + atomicAdd(base_ptr + offset, 1); + } + } + + uint32_t* base_ptr; +}; +} // namespace ck diff --git a/include/ck/utility/workgroup_synchronization.hpp b/include/ck/utility/workgroup_synchronization.hpp new file mode 100644 index 0000000000000000000000000000000000000000..24858fdbdc65a978fec736f9c06ca31c5afd0a0a --- /dev/null +++ b/include/ck/utility/workgroup_synchronization.hpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck/host_utility/hip_check_error.hpp" + +namespace ck { + +// Initialization flag of Barrier object, can be any value except for zero +static constexpr int BarrierInitFlag = 0x7856; + +// 1) only the first thread-block in the synchronizaton group is supposed to call this function. It +// is the responsibility of the user to ensure the two integer values in p_control_bits are zeros +// before calling gms_init(). +// 2) Aftercalling gms_reset(), the two integer values in p_control_bits will be zeros, so no +// repetitious initialization of p_control_bits buffer is required +static __device__ void gms_init(int NumWarps, int* p_control_bits) +{ + union + { + int two32[2]; + unsigned long one64; + } regs; + + regs.two32[0] = BarrierInitFlag; + regs.two32[1] = NumWarps; + + if(threadIdx.x == 0) + atomicCAS(reinterpret_cast(p_control_bits), 0, regs.one64); +}; + +// all the workgroups in the synchronization group is supposed to call this function +static __device__ void gms_barrier(int* p_control_bits) +{ + constexpr int mask = warpSize - 1; + + if((threadIdx.x & mask) == 0) + { + // ensure the barrier object is initialized + do + { + const int r0 = __atomic_load_n(&p_control_bits[0], __ATOMIC_RELAXED); + + if(r0 == BarrierInitFlag) + break; + + } while(true); + + // go ahead toward the barrier line + atomicSub(&p_control_bits[1], 1); + + // wait until all warps have arrived + do + { + const int r1 = __atomic_load_n(&p_control_bits[1], __ATOMIC_RELAXED); + + if(r1 == 0) + break; + + } while(true); + }; +}; + +// 1) Only the first thread-block in the synchronizaton group is supposed to call this function. +// 2) Aftercalling gms_reset(), the two integer values in p_control_bits will be zeros, so no +// repetitious initialization of p_control_bits buffer is required +static __device__ void gms_reset(int* p_control_bits) +{ + // reset the barrier object + if(threadIdx.x == 0) + (void)atomicCAS(&p_control_bits[0], BarrierInitFlag, 0); +}; + +} // namespace ck diff --git a/include/ck/version.h.in b/include/ck/version.h.in new file mode 100644 index 0000000000000000000000000000000000000000..0d6a6512fb44016661d2391b413b0d071cdce790 --- /dev/null +++ b/include/ck/version.h.in @@ -0,0 +1,40 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +/* the configured version and settings for miopen- Composable Kernel */ + +#ifndef CK_VERSION_H_ +#define CK_VERSION_H_ + +// clang-format off +#define CK_VERSION @CMAKE_PROJECT_VERSION@ +#define CK_VERSION_MAJOR @CMAKE_PROJECT_VERSION_MAJOR@ +#define CK_VERSION_MINOR @CMAKE_PROJECT_VERSION_MINOR@ +#define CK_VERSION_PATCH @CMAKE_PROJECT_VERSION_PATCH@ +#define CK_COMMIT_ID @COMMIT_ID@ +// clang-format on + +#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fa06e775602215c1aba0796b0f907fa5748f898a --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// dinput descriptor in [N, C, Do, Ho, Wo] order +// doutput descriptor in [N, C, Di, Hi, Wi] order +// phyiscal layout is irrelavent +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct ReferenceAvgPoolBwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(Tensor& dinput, + const Tensor& doutput, + std::vector window_spatial_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector dinput_left_pads, + std::vector dinput_right_pads) + : dinput_{dinput}, + doutput_{doutput}, + window_spatial_lengths_{window_spatial_lengths}, + window_strides_{window_strides}, + window_dilations_{window_dilations}, + in_left_pads_{dinput_left_pads}, + in_right_pads_{dinput_right_pads} + { + } + + Tensor& dinput_; + const Tensor& doutput_; + + std::vector window_spatial_lengths_; + std::vector window_strides_; + std::vector window_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceAvgPoolBwd::Argument; + + template ::type = false> + float RunAvgPoolBwd(const Argument& arg) + { + // Let input = x, outpu = y + // shape of x = [10], y = [6] + // window_size = 5, pad = 0, stride = 1, dilation = 1 + // Forward: + // y0 = 1/5 * (x0 + x1 + x2 + x3 + x4) + // y1 = 1/5 * (x1 + x2 + x3 + x4 + x5) + // ... + // y5 = 1/5 * (x5 + x6 + x7 + x8 + x9) + // y6 = 1/5 * (x6 + x7 + x8 + x9) + // ... + // y9 = 1/5 * (x9) + + // Backward: + // shape of dy = [6], dx = [10] + // dx0 = 1/5 * dy0 + // dx1 = 1/5 * (dy0 + dy1) + // dx2 = 1/5 * (dy0 + dy1 + dy2) + // ... + // dx4 = 1/5 * (dy0 + dy1 + dy2 + dy3 + dy4) + // dx5 = 1/5 * (dy1 + dy2 + dy3 + dy4 + dy5) + // ... + // dx9 = 1/5 * (dy5 + dy6 + dy7 + dy8 + dy9) + + auto f_ncw = [&](auto n, auto c, auto wi) { + std::size_t X = arg.window_spatial_lengths_[0]; + std::size_t Wo = arg.doutput_.GetLengths()[2]; + + float v_acc = 0; + + for(std::size_t x = 0; x < X; ++x) + { + // Out_Position = (In_Position + pad - x * dilation) / stride + auto w_tmp = static_cast(wi) + + static_cast(arg.in_left_pads_[0]) - + static_cast(x * arg.window_dilations_[0]); + + // Check the input pixel validity (in perspective of being affected by some + // doutput pixel) + if(w_tmp % arg.window_strides_[0] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast(arg.window_strides_[0]); + + // Get the doutput pixel in valid range to accumulate the gradients for this + // input pixel + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + v_acc += ck::type_convert(arg.doutput_(n, c, wo)); + } + } + } + + v_acc /= ck::type_convert(X); + arg.dinput_(n, c, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_ncw, + arg.dinput_.GetLengths()[0], + arg.dinput_.GetLengths()[1], + arg.dinput_.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + + template ::type = false> + float RunAvgPoolBwd(const Argument& arg) + { + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t Y = arg.window_spatial_lengths_[0]; + std::size_t X = arg.window_spatial_lengths_[1]; + + std::size_t Ho = arg.doutput_.GetLengths()[2]; + std::size_t Wo = arg.doutput_.GetLengths()[3]; + + float v_acc = 0; + + for(std::size_t y = 0; y < Y; ++y) + { + // Out_Position = (In_Position + pad - x * dilation) / stride + auto h_tmp = static_cast(hi) + + static_cast(arg.in_left_pads_[0]) - + static_cast(y * arg.window_dilations_[0]); + + // Check the input pixel validity (in perspective of being affected by some + // doutput pixel) + if(h_tmp % arg.window_strides_[0] == 0) + { + auto ho = static_cast(h_tmp) / + static_cast(arg.window_strides_[0]); + + // Get the doutput pixel in valid range to accumulate the gradients for this + // input pixel + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = + static_cast(wi) + + static_cast(arg.in_left_pads_[1]) - + static_cast(x * arg.window_dilations_[1]); + if(w_tmp % arg.window_strides_[1] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast(arg.window_strides_[1]); + if(wo >= 0 && ck::type_convert(wo) < Wo) + { + v_acc += + ck::type_convert(arg.doutput_(n, c, ho, wo)); + } + } + } + } + } + } + + v_acc /= ck::type_convert(Y * X); + arg.dinput_(n, c, hi, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.dinput_.GetLengths()[0], + arg.dinput_.GetLengths()[1], + arg.dinput_.GetLengths()[2], + arg.dinput_.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + template ::type = false> + float RunAvgPoolBwd(const Argument& arg) + { + auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) { + std::size_t Z = arg.window_spatial_lengths_[0]; + std::size_t Y = arg.window_spatial_lengths_[1]; + std::size_t X = arg.window_spatial_lengths_[2]; + + std::size_t Do = arg.doutput_.GetLengths()[2]; + std::size_t Ho = arg.doutput_.GetLengths()[3]; + std::size_t Wo = arg.doutput_.GetLengths()[4]; + + float v_acc = 0; + + for(std::size_t z = 0; z < Z; ++z) + { + // Out_Position = (In_Position + pad - x * dilation) / stride + auto d_tmp = static_cast(di) + + static_cast(arg.in_left_pads_[0]) - + static_cast(z * arg.window_dilations_[0]); + + // Check the input pixel validity (in perspective of being affected by some + // doutput pixel) + if(d_tmp % arg.window_strides_[0] == 0) + { + auto do_ = static_cast(d_tmp) / + static_cast(arg.window_strides_[0]); + + // Get the doutput pixel in valid range to accumulate the gradients for this + // input pixel + if(do_ >= 0 && ck::type_convert(do_) < Do) + { + for(std::size_t y = 0; y < Y; ++y) + { + auto h_tmp = + static_cast(hi) + + static_cast(arg.in_left_pads_[1]) - + static_cast(y * arg.window_dilations_[1]); + if(h_tmp % arg.window_strides_[1] == 0) + { + auto ho = static_cast(h_tmp) / + static_cast(arg.window_strides_[1]); + if(ho >= 0 && ck::type_convert(ho) < Ho) + { + for(std::size_t x = 0; x < X; ++x) + { + auto w_tmp = static_cast(wi) + + static_cast( + arg.in_left_pads_[2]) - + static_cast( + x * arg.window_dilations_[2]); + + if(w_tmp % arg.window_strides_[2] == 0) + { + auto wo = static_cast(w_tmp) / + static_cast( + arg.window_strides_[2]); + if(wo >= 0 && + ck::type_convert(wo) < Wo) + { + v_acc += ck::type_convert( + arg.doutput_(n, c, do_, ho, wo)); + } + } + } + } + } + } + } + } + } + + v_acc /= ck::type_convert(Z * Y * X); + arg.dinput_(n, c, di, hi, wi) = ck::type_convert(v_acc); + }; + + make_ParallelTensorFunctor(f_ncdhw, + arg.dinput_.GetLengths()[0], + arg.dinput_.GetLengths()[1], + arg.dinput_.GetLengths()[2], + arg.dinput_.GetLengths()[3], + arg.dinput_.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const Argument& arg) + { + if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 && + arg.doutput_.GetNumOfDimension() == NDimSpatial + 2)) + { + throw std::runtime_error("wrong! inconsistent dimension"); + } + + return RunAvgPoolBwd(arg); + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(Tensor& dinput, + const Tensor& doutput, + std::vector window_spatial_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector dinput_left_pads, + std::vector dinput_right_pads) + { + if(window_spatial_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial || + window_dilations.size() != NDimSpatial || dinput_left_pads.size() != NDimSpatial || + dinput_right_pads.size() != NDimSpatial) + throw std::runtime_error("dimension is incorrect"); + + return Argument{dinput, + doutput, + window_spatial_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceAvgPoolBwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp index 46a1fa559a1a3d164acc99e536f7e26a83b97ed4..a1b1e0d91b472e5531097b13d57d6efe5b9591a7 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp index 0b621e88a0c6e890bf1efb026615fcc63fadf6c9..a2eabdf5c1c92375503fc8cc961a2cde8410204e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp index dd0db316804fd191ee73d9952f6df6779420910d..20c1fcd7367838022d7a4e2776f436ce796460fe 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp index 463c655ac1d1414f80f9d8b7fcfc696be6326018..7d652fe4c41d044747c2c40a3d6e7383d4f1ff2e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp index b0149d88fdb062806d3b34b400f9b6d3389419f9..24f754e5987adfa82f4b0cc0e41dc85431f94e12 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index 225f7b7e36f3a5a0d70058ed7089fc693a9e0493..50040a2441b6ec5f513f967ebe8fc98a3e5edeeb 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -125,7 +125,7 @@ struct ReferenceConvBwdData : public device::BaseOperator arg.in_element_op_(v_in, v_acc); - arg.input_(g, n, c, wi) = ck::type_convert(v_acc); + arg.input_(g, n, c, wi) = ck::type_convert(v_in); }; make_ParallelTensorFunctor(f_ncw, @@ -201,7 +201,7 @@ struct ReferenceConvBwdData : public device::BaseOperator arg.in_element_op_(v_in, v_acc); - arg.input_(g, n, c, hi, wi) = ck::type_convert(v_acc); + arg.input_(g, n, c, hi, wi) = ck::type_convert(v_in); }; make_ParallelTensorFunctor(f_nchw, @@ -299,7 +299,7 @@ struct ReferenceConvBwdData : public device::BaseOperator arg.in_element_op_(v_in, v_acc); - arg.input_(g, n, c, di, hi, wi) = ck::type_convert(v_acc); + arg.input_(g, n, c, di, hi, wi) = ck::type_convert(v_in); }; make_ParallelTensorFunctor(f_ncdhw, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp index 7d62158f00c0aa28fdfe90d4a9df537367b22717..ec5df238ab2c2a292d8d44951f67e2d8d2383d1a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index b8d47d218b95e3ff2b38c2fc965b6bdc6a2afc68..8f4182a2318017dc7239dcc32e34af5287e97b4d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp index be22003fd90db20eaea80045bd34f908ed8ebb61..71c84a1f5cd43811105a4630e31ade8ae0b3a075 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp index f949f27fde973cc44eec3162c7625980a9818116..0b90b4b50eec00cbd4d807fe2f43b9280cd2bdf2 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index be69f297b246798118688abed92c19cf7f988b63..309b4afad8d6df61fd66b49f10f412408f7eca8f 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -92,11 +92,11 @@ struct ReferenceGemm : public device::BaseOperator ck::type_convert(v_a) * ck::type_convert(v_b); } - AccDataType v_c; + CDataType v_c; arg.c_element_op_(v_c, v_acc); - arg.c_m_n_(m, n) = ck::type_convert(v_c); + arg.c_m_n_(m, n) = v_c; }; make_ParallelTensorFunctor( diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp index 28132aa1ebd78178774caba73b91ba3c947b66be..ce2a83da6163fb6b19e30121a48734096857c508 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp index fedd4dce62cdcd31c2c582ab034291e9e098bc3c..6a48528c543b3ad1bcf94f92effcacbce40a8151 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3f50ab88b3352a77594cb8dc1f44ebd7c38d924e --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +/** + * \brief Reference implementation for image to column. + * + * Tensor descriptor has [G, N, C, Di, Hi, Wi] data layout. + * G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C]. + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam InputLayout Input Layout. + * \tparam InDataType Input Data Type. + * \tparam OutDataType Output Data Type. + */ +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct ReferenceImageToColumn : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + public: + Argument(const Tensor& input, + Tensor& output, + std::vector filter_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : input_{input}, + output_{output}, + conv_strides_{conv_filter_strides}, + conv_dilations_{conv_filter_dilations}, + in_left_pads_{input_left_pads}, + in_right_pads_{input_right_pads}, + filter_spatial_lengths_{filter_spatial_lengths} + { + initOutputSpatialLengths(); + } + + const Tensor& input_; + Tensor& output_; + + std::vector conv_strides_; + std::vector conv_dilations_; + std::vector in_left_pads_; + std::vector in_right_pads_; + + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + + private: + void initOutputSpatialLengths() + { + constexpr auto input_offset_to_spatial = 3; + + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; + + output_spatial_lengths_.push_back( + (input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] + + in_right_pads_[i] - x_eff) / + conv_strides_[i] + + 1); + } + } + }; + + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceImageToColumn::Argument; + + float Run(const Argument& arg) + { + if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 && + arg.output_.GetNumOfDimension() == 2)) + { + throw std::runtime_error("wrong! inconsistent dimension"); + } + + const index_t N = arg.input_.GetLengths()[1]; + const index_t C = arg.input_.GetLengths()[2]; + + if constexpr(NDimSpatial == 1) + { + const index_t Wo = arg.output_spatial_lengths_[0]; + auto func = [&](auto n, auto wo) { + index_t row = n * Wo + wo; + index_t column = 0; + + for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) + { + auto wi = static_cast(wo * arg.conv_strides_[0]) + + static_cast(x * arg.conv_dilations_[0]) - + static_cast(arg.in_left_pads_[0]); + + for(index_t c = 0; c < C; ++c) + { + if(wi >= 0 && + ck::type_convert(wi) < arg.input_.GetLengths()[3]) + { + InDataType v_in = arg.input_(0, n, c, wi); + arg.output_(row, column) = ck::type_convert(v_in); + } + column++; + } + } + }; + + make_ParallelTensorFunctor(func, N, Wo)(std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NDimSpatial == 2) + { + const index_t Ho = arg.output_spatial_lengths_[0]; + const index_t Wo = arg.output_spatial_lengths_[1]; + + auto func = [&](auto n, auto ho, auto wo) { + index_t row = n * Ho * Wo + ho * Wo + wo; + index_t column = 0; + + for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) + { + auto hi = static_cast(ho * arg.conv_strides_[0]) + + static_cast(y * arg.conv_dilations_[0]) - + static_cast(arg.in_left_pads_[0]); + + for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) + { + auto wi = static_cast(wo * arg.conv_strides_[1]) + + static_cast(x * arg.conv_dilations_[1]) - + static_cast(arg.in_left_pads_[1]); + + for(index_t c = 0; c < C; ++c) + { + + if(hi >= 0 && + ck::type_convert(hi) < arg.input_.GetLengths()[3] && + wi >= 0 && + ck::type_convert(wi) < arg.input_.GetLengths()[4]) + { + InDataType v_in = arg.input_(0, n, c, hi, wi); + arg.output_(row, column) = ck::type_convert(v_in); + } + column++; + } + } + } + }; + + make_ParallelTensorFunctor(func, N, Ho, Wo)(std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NDimSpatial == 3) + { + const index_t Do = arg.output_spatial_lengths_[0]; + const index_t Ho = arg.output_spatial_lengths_[1]; + const index_t Wo = arg.output_spatial_lengths_[2]; + + auto func = [&](auto n, auto d_o, auto ho, auto wo) { + index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; + index_t column = 0; + + for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) + { + auto di = static_cast(d_o * arg.conv_strides_[0]) + + static_cast(z * arg.conv_dilations_[0]) - + static_cast(arg.in_left_pads_[0]); + for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) + { + auto hi = static_cast(ho * arg.conv_strides_[1]) + + static_cast(y * arg.conv_dilations_[1]) - + static_cast(arg.in_left_pads_[1]); + for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x) + { + auto wi = + static_cast(wo * arg.conv_strides_[2]) + + static_cast(x * arg.conv_dilations_[2]) - + static_cast(arg.in_left_pads_[2]); + for(index_t c = 0; c < C; ++c) + { + if(di >= 0 && + ck::type_convert(di) < + arg.input_.GetLengths()[3] && + hi >= 0 && + ck::type_convert(hi) < + arg.input_.GetLengths()[4] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.GetLengths()[5]) + { + InDataType v_in = arg.input_(0, n, c, di, hi, wi); + arg.output_(row, column) = + ck::type_convert(v_in); + } + column++; + } + } + } + } + }; + + make_ParallelTensorFunctor(func, N, Do, Ho, Wo)( + std::thread::hardware_concurrency()); + + return 0; + } + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + using namespace tensor_layout::convolution; + + if constexpr(!(std::is_same_v || std::is_same_v || + std::is_same_v)) + { + return false; + } + if constexpr(!(NDimSpatial >= 1 && NDimSpatial <= 3)) + { + return false; + } + return true; + } + + bool IsSupportedArgument(const Argument& arg) + { + const ck::index_t G = arg.input_.GetLengths()[0]; + const ck::index_t N = arg.input_.GetLengths()[1]; + const ck::index_t C = arg.input_.GetLengths()[2]; + + const index_t NDoHoWo = + N * ck::accumulate_n( + arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + if(!(arg.output_.GetLengths()[0] == static_cast(NDoHoWo) && + arg.output_.GetLengths()[1] == static_cast(CZYX))) + { + return false; + } + + if(G != 1) + { + return false; + } + return true; + } + + bool IsSupportedArgument(const device::BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const Tensor& input, + Tensor& output, + std::vector filter_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + return Argument{input, + output, + filter_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceImageToColumn" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp index 2bac5bc5c8f5cd6f1629e8cd108253fc6e9fee01..9994a2f9f7c862a1bf870103f999634e17682ecf 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..60c74fbf14acdce91d290d1d28ee99ad6c77887e --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { +using namespace std; + +template +struct ReferenceMaxPoolBwd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& dout, + const Tensor& indices, + Tensor& din, + ElementwiseOperation elementwise_op) + : dout_(dout), indices_(indices), din_(din), elementwise_op_(elementwise_op) + { + } + + const Tensor& dout_; + const Tensor& indices_; + Tensor& din_; + ElementwiseOperation elementwise_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + int din_length = arg.din_.GetElementSpaceSize(); + int dout_length = arg.dout_.GetElementSpaceSize(); + std::vector buf(din_length, 0); + + for(int i = 0; i < dout_length; ++i) + { + int index = arg.indices_.mData[i]; + if(index >= 0 && index < din_length) + { + if constexpr(is_same_v) + { + float buf_val = ck::type_convert(buf[index]); + buf_val += ck::type_convert(arg.dout_.mData[i]); + buf[index] = ck::type_convert(buf_val); + } + else + buf[index] += ck::type_convert(arg.dout_.mData[i]); + } + } + + for(int i = 0; i < din_length; ++i) + arg.din_.mData[i] = ck::type_convert(buf[i]); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& dout, + const Tensor& indices, + Tensor& din, + ElementwiseOperation elementwise_op) + { + return Argument{dout, indices, din, elementwise_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceMaxPoolBwd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp index 3fc35a83c727f8eaad450eea92ab004a24a24d6e..cf241ac1b107ca4f4527ab6bfc147cbf3773d6b5 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator Tensor& out_indices, const std::vector& window_spatial_lengths, const std::vector& window_strides, + const std::vector& window_dilations, const std::vector& in_left_pads, const std::vector& /*in_right_pads*/) : in_(in), @@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator out_indices_(out_indices), window_spatial_lengths_(window_spatial_lengths), window_strides_(window_strides), + window_dilations_(window_dilations), in_left_pads_(in_left_pads), reduceLength_(1) { @@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator Tensor& out_indices_; const std::vector& window_spatial_lengths_; const std::vector& window_strides_; + const std::vector& window_dilations_; const std::vector& in_left_pads_; int reduceLength_; }; @@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) { - ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0]; + ck::index_t di = do_ * arg.window_strides_[0] + + z * arg.window_dilations_[0] - arg.in_left_pads_[0]; for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y) { - ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1]; + ck::index_t hi = ho * arg.window_strides_[1] + + y * arg.window_dilations_[1] - arg.in_left_pads_[1]; for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x) { - ck::index_t wi = - wo * arg.window_strides_[2] + x - arg.in_left_pads_[2]; + ck::index_t wi = wo * arg.window_strides_[2] + + x * arg.window_dilations_[2] - + arg.in_left_pads_[2]; if(di >= 0 && di < static_cast(arg.in_.mDesc.GetLengths()[2]) && hi >= 0 && @@ -100,8 +106,8 @@ struct ReferencePoolingFwd : public device::BaseOperator wi >= 0 && wi < static_cast(arg.in_.mDesc.GetLengths()[4])) { - ComputeDataType currVal = - static_cast(arg.in_(n, c, di, hi, wi)); + ComputeDataType currVal = ck::type_convert( + arg.in_(n, c, di, hi, wi)); in_elementwise_op(currVal, currVal); @@ -112,7 +118,7 @@ struct ReferencePoolingFwd : public device::BaseOperator } acc_elementwise_op(accuVal, accuVal); - arg.out_(n, c, do_, ho, wo) = accuVal; + arg.out_(n, c, do_, ho, wo) = ck::type_convert(accuVal); }; make_ParallelTensorFunctor(f_ncdhw, @@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) { - ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0]; + ck::index_t di = do_ * arg.window_strides_[0] + + z * arg.window_dilations_[0] - arg.in_left_pads_[0]; for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y) { - ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1]; + ck::index_t hi = ho * arg.window_strides_[1] + + y * arg.window_dilations_[1] - arg.in_left_pads_[1]; for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x) { - ck::index_t wi = - wo * arg.window_strides_[2] + x - arg.in_left_pads_[2]; + ck::index_t wi = wo * arg.window_strides_[2] + + x * arg.window_dilations_[2] - + arg.in_left_pads_[2]; if(di >= 0 && di < static_cast(arg.in_.mDesc.GetLengths()[2]) && hi >= 0 && @@ -151,8 +160,8 @@ struct ReferencePoolingFwd : public device::BaseOperator wi >= 0 && wi < static_cast(arg.in_.mDesc.GetLengths()[4])) { - ComputeDataType currVal = - static_cast(arg.in_(n, c, di, hi, wi)); + ComputeDataType currVal = ck::type_convert( + arg.in_(n, c, di, hi, wi)); IndexDataType currIndex = arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi); @@ -166,7 +175,7 @@ struct ReferencePoolingFwd : public device::BaseOperator acc_elementwise_op(accuVal, accuVal); - arg.out_(n, c, do_, ho, wo) = accuVal; + arg.out_(n, c, do_, ho, wo) = ck::type_convert(accuVal); arg.out_indices_(n, c, do_, ho, wo) = accuIndex; }; @@ -202,17 +211,19 @@ struct ReferencePoolingFwd : public device::BaseOperator for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) { - ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0]; + ck::index_t hi = ho * arg.window_strides_[0] + + y * arg.window_dilations_[0] - arg.in_left_pads_[0]; for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x) { - ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1]; + ck::index_t wi = wo * arg.window_strides_[1] + + x * arg.window_dilations_[1] - arg.in_left_pads_[1]; if(hi >= 0 && hi < static_cast(arg.in_.mDesc.GetLengths()[2]) && wi >= 0 && wi < static_cast(arg.in_.mDesc.GetLengths()[3])) { ComputeDataType currVal = - static_cast(arg.in_(n, c, hi, wi)); + ck::type_convert(arg.in_(n, c, hi, wi)); in_elementwise_op(currVal, currVal); @@ -222,7 +233,7 @@ struct ReferencePoolingFwd : public device::BaseOperator } acc_elementwise_op(accuVal, accuVal); - arg.out_(n, c, ho, wo) = accuVal; + arg.out_(n, c, ho, wo) = ck::type_convert(accuVal); }; make_ParallelTensorFunctor(f_nchw, @@ -245,17 +256,19 @@ struct ReferencePoolingFwd : public device::BaseOperator for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) { - ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0]; + ck::index_t hi = ho * arg.window_strides_[0] + + y * arg.window_dilations_[0] - arg.in_left_pads_[0]; for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x) { - ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1]; + ck::index_t wi = wo * arg.window_strides_[1] + + x * arg.window_dilations_[1] - arg.in_left_pads_[1]; if(hi >= 0 && hi < static_cast(arg.in_.mDesc.GetLengths()[2]) && wi >= 0 && wi < static_cast(arg.in_.mDesc.GetLengths()[3])) { ComputeDataType currVal = - static_cast(arg.in_(n, c, hi, wi)); + ck::type_convert(arg.in_(n, c, hi, wi)); IndexDataType currIndex = arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi); @@ -268,7 +281,7 @@ struct ReferencePoolingFwd : public device::BaseOperator } acc_elementwise_op(accuVal, accuVal); - arg.out_(n, c, ho, wo) = accuVal; + arg.out_(n, c, ho, wo) = ck::type_convert(accuVal); arg.out_indices_(n, c, ho, wo) = accuIndex; }; @@ -308,6 +321,7 @@ struct ReferencePoolingFwd : public device::BaseOperator Tensor& out_indices, const std::vector& window_spatial_lengths, const std::vector& window_strides, + const std::vector& window_dilations, const std::vector& in_left_pads, const std::vector& in_right_pads) { @@ -316,6 +330,7 @@ struct ReferencePoolingFwd : public device::BaseOperator out_indices, window_spatial_lengths, window_strides, + window_dilations, in_left_pads, in_right_pads}; } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp index c04baca57491ddce5f4dd09469ddb116840bcecd..944f34007ef0e030e2017688477c27602df15391 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp index a4fd46c932c08f3d51df8db212e167d66b755c27..9916a03b9c8f2eabe20ef0c27d78def6f4b9d6a5 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp index b6a9b0fb5ee4a3135750dad0d0818a0f6e439f3a..f949260ca4552b83ab6e8edafee9190d4f95cc64 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp index df4fca6562755012200ff888677bcac1c90129a2..0b7887efb6e05f34b83a5bf684b5fedd3d78c385 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef NAIVE_CONV_FWD_HPP #define NAIVE_CONV_FWD_HPP diff --git a/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp index 20df1b3616a016029f4083bd821b4e286b7b727c..f57fed9c07c7843af0f77db7991ce077b3021245 100644 --- a/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 605e83c2cd86fcb72363e4233efb5dc7661998ee..84d31ce2675b88706b843cbe96ad320caea8575c 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,6 +17,7 @@ namespace instance { using F64 = double; using F32 = float; using F16 = ck::half_t; +using F8 = ck::f8_t; using BF16 = ck::bhalf_t; using I8 = int8_t; using I32 = int32_t; @@ -32,6 +33,8 @@ using I32_Tuple = ck::Tuple; using I32_F32_Tuple = ck::Tuple; using I8_Tuple = ck::Tuple; +using F32_F32_Tuple = ck::Tuple; + // GEMM layout using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -95,9 +98,11 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; using FastGelu = ck::tensor_operation::element_wise::FastGelu; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using Gelu = ck::tensor_operation::element_wise::Gelu; using Swish = ck::tensor_operation::element_wise::Swish; +using Add = ck::tensor_operation::element_wise::Add; template using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/avg_pool3d_bwd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/avg_pool3d_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..949e1d2dd071a7b9f9f68eea81e4f168ca9a6bc0 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/avg_pool3d_bwd.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_avgpool_bwd_ndhwc_f16_instances( + std::vector>>&); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_avgpool_bwd_ndhwc_bf16_instances( + std::vector>>&); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_avgpool_bwd_ndhwc_f32_instances( + std::vector>>&); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device:: + DeviceAvgPoolBwd<3, DOutDataType, DInDataType, InLayout, OutLayout>> +{ + using DeviceOp = DeviceAvgPoolBwd<3, DOutDataType, DInDataType, InLayout, OutLayout>; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v) + add_device_avgpool_bwd_ndhwc_f16_instances(op_ptrs); +#endif +#ifdef CK_ENABLE_BF16 + else if constexpr(is_same_v && is_same_v) + add_device_avgpool_bwd_ndhwc_bf16_instances(op_ptrs); +#endif +#ifdef CK_ENABLE_FP32 + else if constexpr(is_same_v && is_same_v) + add_device_avgpool_bwd_ndhwc_f32_instances(op_ptrs); +#endif + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp index bb5f971c779e870aa686a55981ec680933e383de..8f15e80794d8f8b18d8b1194e4a0e1d2bbfd5967 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_ENABLE_BF16 void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( std::vector>>& @@ -36,7 +36,8 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( std::vector>>& @@ -56,7 +57,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( std::vector>>& @@ -76,7 +78,8 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( std::vector>>& instances); - +#endif template > op_ptrs; - +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -176,8 +179,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -200,8 +205,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -224,8 +231,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -248,7 +257,7 @@ struct DeviceOperationInstanceFactory>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( std::vector>>& instances); - +#endif template > op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && Acc0BiasDataType::Size() == 1 && @@ -164,6 +165,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && Acc0BiasDataType::Size() == 1 && @@ -180,6 +183,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp index e1a4391c4a7246377dee190e876aeb2800c5e3fa..77ad36b97b9e65c5fdb1e6fdc72c806321d830a1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_ENABLE_FP16 void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( std::vector +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +#ifdef DL_KERNELS +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_INT8 +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances( + std::vector>>& instances); + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances( + std::vector>>& instances); +#endif +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceBatchedGemmMultiD; + + static auto GetInstances() + { + std::vector> op_ptrs; +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances( + op_ptrs); + } + } +#endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances(op_ptrs); + add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances( + op_ptrs); + } + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp index 8a0b1b1fa7de54954a5bc375241bb7a0686b2e49..88bdaea23ad35668b3ee547113a0ad5a9086e65b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - +#ifdef CK_ENABLE_FP16 namespace ck { namespace tensor_operation { namespace device { @@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory< } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp index 34c86dd44076247e300d7b2ed409a65699652e9b..1814cfabb4c8eb6126a29ad259c49b8d4a2d9e25 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_ENABLE_FP16 void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( std::vector>>& instances); +#endif template > op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -161,6 +163,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -175,6 +179,7 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp index c84ffcff8cbb8d4820f88dbfa3e7de6a04961bb8..0e1f6f04e8c58812651d3d549601952050fb73dd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp index 8e40d60c17b8dc8d743e29cc1472a1e9eed9bdc1..8fd1c7665d143e9a0fb07f47f07ead18aa33aa88 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp index 342ade69cdf499c8a8a960038c288d4921a813b6..f6f4df7e2ec0e34b6ac7a4b356649289f6a30c91 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp index 6a551c72683abaf3008ced53b11ae079cc4dc717..bd3af891ec2984e8841536a63ba0de950d8aab5c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_ENABLE_FP32 // float void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP64 // double void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( std::vector>>& instances); - +#endif // Contraction + Bilinear template > op_ptrs; - +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -165,7 +166,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v) { @@ -181,7 +183,7 @@ struct DeviceOperationInstanceFactory>>& instances); - +#endif +#ifdef CK_ENABLE_FP64 // double void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( std::vector>>& instances); - +#endif // Contraction + Scale template > op_ptrs; - +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -164,7 +165,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -180,7 +182,7 @@ struct DeviceOperationInstanceFactory>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 // conv2d backward data void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( std::vector>>& instances); - +#endif +#ifdef DL_KERNELS +#ifdef CK_ENABLE_FP16 // conv2d dl void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( std::vector>>& instances); +#endif +#endif +#ifdef CK_ENABLE_BF16 // conv3d backward data void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances( std::vector>>& instances); - +#endif template && is_same_v && - is_same_v) +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs); } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs); } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs); } +#endif } else if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) @@ -252,26 +274,37 @@ struct DeviceOperationInstanceFactory) { add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); +#ifdef DL_KERNELS add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs); +#endif } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); +#ifdef DL_KERNELS add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); +#endif } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); +#ifdef DL_KERNELS add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); +#endif } +#endif } else if constexpr(NumDimSpatial == 3 && is_same_v && is_same_v && is_same_v) @@ -281,22 +314,27 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs); } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs); } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs); } +#endif } return op_ptrs; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp index 2c529e06fc55c1fd9c71d16c7c137a086924a75b..ad2da3364fe34ed581f4b4a40afa1627a13a660f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,11 +18,17 @@ namespace device { namespace instance { // conv2d forward +#ifdef CK_ENABLE_FP16 void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( std::vector>>& instances); - +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( std::vector>>& instances); - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( - std::vector>>& - instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( std::vector>>& instances); +#endif template && is_same_v && is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs); add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); } +#endif } return op_ptrs; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp index 8af400cb7ad4daa31095cea130fdb8e2f9b4a842..b03693b00aabceef8614a56e91c0ce9b89207def 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp @@ -1,15 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" namespace ck { namespace tensor_operation { @@ -29,20 +28,34 @@ template -auto get_device_normalize_from_mean_meansquare_instances() +struct DeviceOperationInstanceFactory, + ck::Tuple, + Normalize, + 2>> { - std::vector op_ptrs; + using DeviceOp = DeviceElementwise< + ck::Tuple, + ck::Tuple, + Normalize, + 2>; - if constexpr(is_same::value && is_same::value && - is_same::value && is_same::value && - is_same::value && is_same::value) + static auto GetInstances() { - ck::tensor_operation::device::instance:: - add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs); - } - - return op_ptrs; -} + std::vector> op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value && + is_same::value && + is_same::value && is_same::value) + { + ck::tensor_operation::device::instance:: + add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(op_ptrs); + } + + return op_ptrs; + }; +}; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp index 3d0c34062cbc6c34bf82c40678e1fb56fc64fca5..b15139510935586a930cd5a7c5a091299dfddecc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_mean_squaremean_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp index c87ae159bee2635c903490b9e09096be8c242c8c..931110267a9c99fb8020a593703ac14a7aed54d2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - +#ifdef CK_ENABLE_FP16 namespace ck { namespace tensor_operation { namespace device { @@ -77,3 +77,4 @@ struct DeviceOperationInstanceFactory>>& instances); +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( std::vector>>& @@ -32,6 +38,11 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( std::vector>>& @@ -42,6 +53,11 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( std::vector>>& @@ -52,15 +68,20 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); +#endif +#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS) void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( std::vector>>& - instances); void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( @@ -77,7 +98,8 @@ void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( std::vector>>& instances); - +#endif +#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS) void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( std::vector>>& @@ -117,30 +139,32 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( +#endif +#ifdef CK_ENABLE_INT8 +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( std::vector>>& + DeviceGemm>>& instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( +#endif +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( std::vector>>& + DeviceGemm>>& instances); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( @@ -163,64 +187,67 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( std::vector>>& + DeviceGemm>>& instances); - -void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( +#endif +#ifdef CK_ENABLE_FP32 +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( std::vector>>& + DeviceGemm>>& instances); -void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( std::vector>>& + DeviceGemm>>& instances); void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( @@ -242,7 +269,8 @@ void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP64 void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( std::vector>>& instances); - +#endif template ) { add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); +#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); +#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); +#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); +#endif add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); } } +#ifdef CK_ENABLE_FP16 else if constexpr(is_same_v && is_same_v && is_same_v) { @@ -334,16 +371,22 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); +#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); +#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } @@ -351,19 +394,27 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); +#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); +#endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } +#endif +#ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v) { @@ -388,6 +439,8 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs); } } +#endif +#ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && is_same_v) { @@ -395,32 +448,40 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); +#ifdef DL_KERNELS add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); +#endif } } - +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp index 2d578cca4fb37cbcce9fe5fad6c79446379d40e1..99b2ad13152f0baa82872add5efdeb9b6a1988a8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp index 554437f4903bcae96f68f2a53097a6eaf88367f3..fd3550c2f01b733602e6ae84dcca34f74255c9fe 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp index c07ca3134bb889e5d15fc5df521979d5ea2efe46..481915d00b7416a444eabdf92289d275a9f60ef9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp index 7beae83cdc1799814a2e7f4a1a4e7fa9e2d59a08..dd8ecae62ce72ce675d845c5b36d2467574c6acc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,7 +9,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - +#ifdef CK_ENABLE_FP16 namespace ck { namespace tensor_operation { namespace device { @@ -170,3 +170,4 @@ struct DeviceOperationInstanceFactory +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances( + std::vector>>&); + +// GEMM + Multiply + Add +template +struct DeviceOperationInstanceFactory, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::MultiplyAdd>> +{ + using DeviceOp = DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::MultiplyAdd>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + op_ptrs); + } + } + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp index e88844694831cc0bc97b76ec7651cea998385947..dbd2b8f65695d6886fd979b53f46ef1146d3549e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,6 +57,46 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( DeviceGemmSplitK>>& instances); +void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector>>& + instances); + template && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances(op_ptrs); + } + } return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2df378b0c6365d705138e9da9904f89403cce99d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGemmStreamK; + + static auto GetInstances() + { + std::vector> op_ptrs; +#if 0 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } +#endif + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..26acf4f5f7413dc8d9087bcee07d0d30e3161559 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// f16_f16_f32_f16 +template +using device_grouped_conv_bwd_data_xdl_f16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +// bf16_bf16_f32_bf16 +template +using device_grouped_conv_bwd_data_xdl_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +// f32_f32_f32_f32 +template +using device_grouped_conv_bwd_data_xdl_f32_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6fc91565b96e4ca991413870fc0592aa1a416ff4 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, + // instance for small conv.K + // for fp16 conv.K and conv.C must be divisible by 2 + // since half_t atomic_add require scalar_per_x_vector % 2 == 0 + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + // instance for small conv.K + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c9cf0f8e1d4682384085bcc4c1dd009dc74e0913 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; +using I32 = int32_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NHWGC = ck::tensor_layout::convolution::NHWGC; +using GNHWC = ck::tensor_layout::convolution::GNHWC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using NHWGK = ck::tensor_layout::convolution::NHWGK; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using device_grouped_conv2d_fwd_wmma_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // blocksize=256 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 8, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 8, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +template +using device_grouped_conv2d_fwd_wmma_i8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // blocksize=256 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 16, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 16, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 16, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 16, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp rename to library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp index bcda22006b3ee788f8db1dd2cf427e6d510d425c..86ff43e5cf6b9af3a9f70c37912ce26ebb3ce8f7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp @@ -1,14 +1,45 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" -#include "device_grouped_conv2d_fwd_common.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + template will be supported + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 2, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 1, 2>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; @@ -40,6 +76,10 @@ using device_grouped_conv2d_fwd_dl_f32_instances = std::tuple< // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instances + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F32, F32, DsDatatype, F32, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 1, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 1>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F32, F32, DsDatatype, F32, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F32, F32, DsDatatype, F32, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..23edf35e98cb8f059cc30994cc111b3853747bd7 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_f32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_int8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 81b2b4fcf37cb1f0d05415eb7d56587f0a29f380..58c9064535e771dfd5398ffc7f6e44e547d1979c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,7 +16,7 @@ namespace device { namespace instance { // conv2d backward data -void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); + +// conv3d backward data +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances); + template > op_ptrs; + if constexpr(NumDimSpatial == 2) + { - if constexpr(NumDimSpatial == 2 && is_same_v && - is_same_v && is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( + op_ptrs); + } + } + } + else if constexpr(NumDimSpatial == 3) { - if constexpr(is_same_v && is_same_v && - is_same_v) + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + } } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index ef6920e52a2096a3f61e40435c6fa3ce64a17323..cad3e1ace82fce5dfb895459e415c8754c4f943e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -91,6 +91,42 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); + // conv3d backward weight void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); + template > op_ptrs; - if constexpr(NumDimSpatial == 1 && is_same_v && - is_same_v && is_same_v) + if constexpr(NumDimSpatial == 1) { - if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); - } - else if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); - } - else if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - op_ptrs); + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + op_ptrs); + } } } - else if constexpr(NumDimSpatial == 2 && is_same_v && - is_same_v && is_same_v) + else if constexpr(NumDimSpatial == 2) { - if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + op_ptrs); + } } - else if constexpr(is_same_v && is_same_v && - is_same_v) + else if constexpr(is_same_v && is_same_v && + is_same_v) { - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - } - else if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - op_ptrs); + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + op_ptrs); + } } } - else if constexpr(NumDimSpatial == 3 && is_same_v && - is_same_v && is_same_v) + else if constexpr(NumDimSpatial == 3) { - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - op_ptrs); - } - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { - add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( - op_ptrs); + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + op_ptrs); + } } - else if constexpr(is_same_v && is_same_v && - is_same_v) + else if constexpr(is_same_v && is_same_v && + is_same_v) { - add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - op_ptrs); + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index fc4beb0ae9247d9ee9f573d345778d650fc55f69..576b9d8983c75b8acb022eaa453ba59bf3b4f84b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -145,6 +145,60 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( + std::vector>>& instances); +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); + // grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances); -// grouped conv3d forward, NDHWGC/KZYXGC/NDHWGK -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( std::vector && is_same_v && @@ -386,6 +441,11 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); + } } else if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) @@ -394,11 +454,13 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && @@ -433,28 +495,28 @@ struct DeviceOperationInstanceFactory && - is_same_v && is_same_v) + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index e97484a5add5a12c4c86be11782ed106f9f3ed61..070c7e5b17f94ed6b272557f29fa86d60941d615 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,7 +10,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - +#ifdef CK_ENABLE_FP16 namespace ck { namespace tensor_operation { namespace device { @@ -192,3 +192,4 @@ struct DeviceOperationInstanceFactory +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// fp16_output +void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +// fp32_output +void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedGemmFixedNK> +{ + using DeviceOp = DeviceGroupedGemmFixedNK; + + static auto GetInstances() + { + std::vector> op_ptrs; + + // fp16_output + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + } + + // fp32_output + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/image_to_column.hpp b/library/include/ck/library/tensor_operation_instance/gpu/image_to_column.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6c4526ba4e6bd893b03a517dc8f7b1e68261fa85 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/image_to_column.hpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_image_to_column.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// nhwc, 1d +void add_device_image_to_column_nhwc_1d_bf16_instances( + std::vector>>& instances); + +void add_device_image_to_column_nhwc_1d_f16_instances( + std::vector>>& instances); + +void add_device_image_to_column_nhwc_1d_f32_instances( + std::vector>>& instances); + +void add_device_image_to_column_nhwc_1d_i8_instances( + std::vector>>& instances); +// nhwc, 2d +void add_device_image_to_column_nhwc_2d_bf16_instances( + std::vector>>& instances); + +void add_device_image_to_column_nhwc_2d_f16_instances( + std::vector>>& instances); + +void add_device_image_to_column_nhwc_2d_f32_instances( + std::vector>>& instances); + +void add_device_image_to_column_nhwc_2d_i8_instances( + std::vector>>& instances); +// nhwc, 3d +void add_device_image_to_column_nhwc_3d_bf16_instances( + std::vector>>& instances); + +void add_device_image_to_column_nhwc_3d_f16_instances( + std::vector>>& instances); + +void add_device_image_to_column_nhwc_3d_f32_instances( + std::vector>>& instances); + +void add_device_image_to_column_nhwc_3d_i8_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device:: + DeviceImageToColumn> +{ + using DeviceOp = DeviceImageToColumn; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 1 && is_same_v) + { + if constexpr(is_same_v && is_same_v) + { + add_device_image_to_column_nhwc_1d_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v) + { + add_device_image_to_column_nhwc_1d_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v) + { + add_device_image_to_column_nhwc_1d_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v) + { + add_device_image_to_column_nhwc_1d_i8_instances(op_ptrs); + } + } + else if constexpr(NumDimSpatial == 2 && is_same_v) + { + if constexpr(is_same_v && is_same_v) + { + add_device_image_to_column_nhwc_2d_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v) + { + add_device_image_to_column_nhwc_2d_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v) + { + add_device_image_to_column_nhwc_2d_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v) + { + add_device_image_to_column_nhwc_2d_i8_instances(op_ptrs); + } + } + else if constexpr(NumDimSpatial == 3 && is_same_v) + { + if constexpr(is_same_v && is_same_v) + { + add_device_image_to_column_nhwc_3d_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v) + { + add_device_image_to_column_nhwc_3d_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v) + { + add_device_image_to_column_nhwc_3d_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v) + { + add_device_image_to_column_nhwc_3d_i8_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/image_to_column/device_image_to_column_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/image_to_column/device_image_to_column_instance.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a2603218b240c4a350aed98e083903cac2740b4f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/image_to_column/device_image_to_column_instance.hpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +template +using device_image_to_column_bf16_instances = std::tuple< + // clang-format off + //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 8>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 8>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 8> + // clang-format on + >; + +template +using device_image_to_column_f16_instances = std::tuple< + // clang-format off + //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 8>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 8>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 8> + // clang-format on + >; + +template +using device_image_to_column_f32_instances = std::tuple< + // clang-format off + //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 4> + // clang-format on + >; + +template +using device_image_to_column_i8_instances = std::tuple< + // clang-format off + //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 8>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 8>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 1>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 4>, + DeviceImageToColumnImpl, 8>, + DeviceImageToColumnImpl, 16> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..63ea4f2891ad4976984b6373a1c4fb53f5a0a7a5 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_maxpool_bwd_f16_instances( + std::vector>>&); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_maxpool_bwd_bf16_instances( + std::vector>>&); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_maxpool_bwd_f32_instances( + std::vector>>&); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceMaxPoolBwd> +{ + using DeviceOp = DeviceMaxPoolBwd; + + static auto GetInstances() + { + std::vector> op_ptrs; +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + add_device_maxpool_bwd_f16_instances(op_ptrs); +#endif +#ifdef CK_ENABLE_BF16 + else if constexpr(is_same_v && is_same_v && + is_same_v) + add_device_maxpool_bwd_bf16_instances(op_ptrs); +#endif +#ifdef CK_ENABLE_FP32 + else if constexpr(is_same_v && is_same_v && + is_same_v) + add_device_maxpool_bwd_f32_instances(op_ptrs); +#endif + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp index 199ed73b4c233fd83c747f2abec60b977b31c1ae..8e90a7ea983c99611129c9769d0651bed8889748 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_ENABLE_FP16 // FP16 void add_device_normalization_rank_2_1_f16_instances( std::vector>>&); @@ -26,7 +26,8 @@ void add_device_normalization_rank_4_3_f16_instances( void add_device_normalization_rank_5_3_f16_instances( std::vector>>&); - +#endif +#ifdef CK_ENABLE_FP32 // FP32 void add_device_normalization_rank_2_1_f32_instances( std::vector>>&); @@ -36,7 +37,7 @@ void add_device_normalization_rank_4_3_f32_instances( void add_device_normalization_rank_5_3_f32_instances( std::vector>>&); - +#endif template > op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -82,8 +83,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(Rank == 2 && NumReduceDim == 1) { @@ -98,7 +101,7 @@ struct DeviceOperationInstanceFactory - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr auto InOutRank = 4; -static constexpr auto WindowRank = 2; - -static constexpr auto MaxOp = ck::ReduceTensorOp::MAX; -static constexpr auto AvgOp = ck::ReduceTensorOp::AVG; - -// FP16 -void add_device_pool2d_fwd_nhwc_f16_instances( - std::vector< - std::unique_ptr>>&); - -void add_device_pool2d_fwd_nhwc_f16_instances( - std::vector< - std::unique_ptr>>&); - -// FP16 - return index -void add_device_pool2d_fwd_nhwc_index_f16_instances( - std::vector< - std::unique_ptr>>&); - -// FP32 -void add_device_pool2d_fwd_nhwc_f32_instances( - std::vector< - std::unique_ptr>>&); - -void add_device_pool2d_fwd_nhwc_f32_instances( - std::vector< - std::unique_ptr>>&); - -// FP32 - return index -void add_device_pool2d_fwd_nhwc_index_f32_instances( - std::vector< - std::unique_ptr>>&); - -template -struct DeviceOperationInstanceFactory> -{ - using DeviceOp = DevicePoolFwd; - - static auto GetInstances() - { - std::vector> op_ptrs; - - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(OutputIndex && ReduceOpId == MaxOp) - { - add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs); - } - else - { - add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs); - } - } - else if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(OutputIndex && ReduceOpId == MaxOp) - { - add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs); - } - else - { - add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs); - } - } - - return op_ptrs; - } -}; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp index 88523c70399854b2fb575e8476cba795b4c300d6..94ee68a409d6806ac0671f529bb4977ef367ed90 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,38 +22,56 @@ static constexpr auto WindowRank = 3; static constexpr auto MaxOp = ck::ReduceTensorOp::MAX; static constexpr auto AvgOp = ck::ReduceTensorOp::AVG; - +#ifdef CK_ENABLE_FP16 // FP16 void add_device_pool3d_fwd_ndhwc_f16_instances( - std::vector< - std::unique_ptr>>&); + std::vector>>&); void add_device_pool3d_fwd_ndhwc_f16_instances( - std::vector< - std::unique_ptr>>&); + std::vector>>&); // FP16 - return index void add_device_pool3d_fwd_ndhwc_index_f16_instances( - std::vector< - std::unique_ptr>>&); - + std::vector>>&); +#endif +#ifdef CK_ENABLE_BF16 +// BF16 +void add_device_pool3d_fwd_ndhwc_bf16_instances( + std::vector>>&); + +void add_device_pool3d_fwd_ndhwc_bf16_instances( + std::vector>>&); + +// BF16 - return index +void add_device_pool3d_fwd_ndhwc_index_bf16_instances( + std::vector>>&); +#endif +#ifdef CK_ENABLE_FP32 // FP32 void add_device_pool3d_fwd_ndhwc_f32_instances( - std::vector< - std::unique_ptr>>&); + std::vector>>&); void add_device_pool3d_fwd_ndhwc_f32_instances( - std::vector< - std::unique_ptr>>&); + std::vector>>&); // FP32 - return index void add_device_pool3d_fwd_ndhwc_index_f32_instances( - std::vector< - std::unique_ptr>>&); - + std::vector>>&); +#endif template struct DeviceOperationInstanceFactory> { @@ -69,36 +89,58 @@ struct DeviceOperationInstanceFactory; static auto GetInstances() { std::vector> op_ptrs; - - if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v) { - if constexpr(OutputIndex && ReduceOpId == MaxOp) +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { - add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs); + if constexpr(OutputIndex && ReduceOpId == MaxOp) + { + add_device_pool3d_fwd_ndhwc_index_f16_instances(op_ptrs); + } + else + { + add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs); + } } - else - { - add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs); - } - } - else if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(OutputIndex && ReduceOpId == MaxOp) +#endif +#ifdef CK_ENABLE_BF16 + else if constexpr(is_same_v && is_same_v && + is_same_v) { - add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs); + if constexpr(OutputIndex && ReduceOpId == MaxOp) + { + add_device_pool3d_fwd_ndhwc_index_bf16_instances(op_ptrs); + } + else + { + add_device_pool3d_fwd_ndhwc_bf16_instances(op_ptrs); + } } - else +#endif +#ifdef CK_ENABLE_FP32 + else if constexpr(is_same_v && is_same_v && + is_same_v) { - add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs); + if constexpr(OutputIndex && ReduceOpId == MaxOp) + { + add_device_pool3d_fwd_ndhwc_index_f32_instances(op_ptrs); + } + else + { + add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs); + } } +#endif } return op_ptrs; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp index 2fd7ce22f7010f9e7b738395b5f670ac14aa9924..19600a90f89cade9ea42000f7b3a69b4b7544f98 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,12 +11,12 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef DL_KERNELS // Layout(A, B, C) = [Col, Row, Row] void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector>>>& instances); - +#endif // Layout(A, B, C) = [Col, Row, Row] void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector) { +#ifdef DL_KERNELS add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); +#endif add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); } } @@ -190,7 +192,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); +#endif add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); } } @@ -199,7 +203,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); +#endif add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); } } @@ -208,7 +214,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); +#endif add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); } } @@ -222,3 +230,4 @@ struct DeviceOperationInstanceFactory>>>& instances); - +#endif void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( std::vector< std::unique_ptr) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(op_ptrs); } else if constexpr(is_same_v) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(op_ptrs); } } @@ -229,7 +233,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs); } } @@ -243,3 +249,4 @@ struct DeviceOperationInstanceFactory>>>& instances); - +#endif void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( std::vector< std::unique_ptr) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(op_ptrs); } else if constexpr(is_same_v) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(op_ptrs); } } @@ -227,7 +231,9 @@ struct DeviceOperationInstanceFactory) { +#ifdef DL_KERNELS add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs); } } @@ -241,3 +247,4 @@ struct DeviceOperationInstanceFactory>>>& instances); - +#endif void add_device_conv2d_xdl_perchannel_quantization_int8_instances( std::vector) { +#ifdef DL_KERNELS add_device_conv2d_dl_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_perchannel_quantization_int8_instances(op_ptrs); } else if constexpr(is_same_v) { +#ifdef DL_KERNELS add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(op_ptrs); } } @@ -147,3 +151,4 @@ struct DeviceOperationInstanceFactory>>>& instances); - +#endif void add_device_conv2d_xdl_perlayer_quantization_int8_instances( std::vector) { +#ifdef DL_KERNELS add_device_conv2d_dl_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_perlayer_quantization_int8_instances(op_ptrs); } else if constexpr(is_same_v) { +#ifdef DL_KERNELS add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(op_ptrs); +#endif add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(op_ptrs); } } @@ -144,3 +148,4 @@ struct DeviceOperationInstanceFactory::value, 1>{}( [&](auto i) { - using cfg1 = remove_cvref_t(reduce_configuration_1_instances_blockwise{}))>; + using cfg1 = remove_cvref_t( + reduce_configuration_1_instances_blockwise{}))>; static_for<0, std::tuple_size::value, 1>{}( [&](auto j) { - using cfg2 = remove_cvref_t(reduce_configuration_2_instances_blockwise{}))>; + using cfg2 = remove_cvref_t( + reduce_configuration_2_instances_blockwise{}))>; using ReduceOpInstance = DeviceReduceMultiBlock::value, 1>{}([&](auto i) { - using cfg1 = remove_cvref_t(reduce_configuration_1_instances_multiblock_atomic_add{}))>; + using cfg1 = remove_cvref_t( + reduce_configuration_1_instances_multiblock_atomic_add{}))>; static_for<0, std::tuple_size::value, 1>{}([&](auto j) { - using cfg2 = remove_cvref_t(reduce_configuration_2_instances_multiblock_atomic_add{}))>; + using cfg2 = remove_cvref_t( + reduce_configuration_2_instances_multiblock_atomic_add{}))>; using ReduceOpInstance = DeviceReduceMultiBlock::value, 1>{}( [&](auto j) { - using cfg2 = remove_cvref_t(reduce_configuration_2_instances_threadwise{}))>; + using cfg2 = remove_cvref_t( + reduce_configuration_2_instances_threadwise{}))>; using ReduceOpInstance = DeviceReduceThreadWise>&); -void add_device_softmax_f16_f16_rank4_instances( - std::vector>&); - -void add_device_softmax_f32_f32_rank3_instances( - std::vector>&); -void add_device_softmax_f32_f32_rank4_instances( - std::vector>&); - -void add_device_softmax_i8_i8_rank3_instances( - std::vector>&); -void add_device_softmax_i8_i8_rank4_instances( - std::vector>&); - -template -struct DeviceOperationInstanceFactory< - ck::tensor_operation::device:: - DeviceSoftmax> +template +struct DeviceOperationInstanceFactory> { - using DeviceOp = - DeviceSoftmax; + using DeviceOp = DeviceSoftmax; static auto GetInstances() { std::vector> op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { if constexpr(Rank == 3) - add_device_softmax_f16_f16_rank3_instances(op_ptrs); + { + if constexpr(NumReduceDim == 1) + add_device_softmax_f16_f16_rank3_reduce1_instances(op_ptrs); + else if constexpr(NumReduceDim == 2) + add_device_softmax_f16_f16_rank3_reduce2_instances(op_ptrs); + else if constexpr(NumReduceDim == 3) + add_device_softmax_f16_f16_rank3_reduce3_instances(op_ptrs); + } else if constexpr(Rank == 4) - add_device_softmax_f16_f16_rank4_instances(op_ptrs); + { + if constexpr(NumReduceDim == 1) + add_device_softmax_f16_f16_rank4_reduce1_instances(op_ptrs); + else if constexpr(NumReduceDim == 2) + add_device_softmax_f16_f16_rank4_reduce2_instances(op_ptrs); + else if constexpr(NumReduceDim == 3) + add_device_softmax_f16_f16_rank4_reduce3_instances(op_ptrs); + else if constexpr(NumReduceDim == 4) + add_device_softmax_f16_f16_rank4_reduce4_instances(op_ptrs); + } } - else if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) { if constexpr(Rank == 3) - add_device_softmax_f32_f32_rank3_instances(op_ptrs); + { + if constexpr(NumReduceDim == 1) + add_device_softmax_f32_f32_rank3_reduce1_instances(op_ptrs); + else if constexpr(NumReduceDim == 2) + add_device_softmax_f32_f32_rank3_reduce2_instances(op_ptrs); + else if constexpr(NumReduceDim == 3) + add_device_softmax_f32_f32_rank3_reduce3_instances(op_ptrs); + } else if constexpr(Rank == 4) - add_device_softmax_f32_f32_rank4_instances(op_ptrs); + { + if constexpr(NumReduceDim == 1) + add_device_softmax_f32_f32_rank4_reduce1_instances(op_ptrs); + else if constexpr(NumReduceDim == 2) + add_device_softmax_f32_f32_rank4_reduce2_instances(op_ptrs); + else if constexpr(NumReduceDim == 3) + add_device_softmax_f32_f32_rank4_reduce3_instances(op_ptrs); + else if constexpr(NumReduceDim == 4) + add_device_softmax_f32_f32_rank4_reduce4_instances(op_ptrs); + } } - else if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) - { - if constexpr(Rank == 3) - add_device_softmax_i8_i8_rank3_instances(op_ptrs); - else if constexpr(Rank == 4) - add_device_softmax_i8_i8_rank4_instances(op_ptrs); - } - +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp deleted file mode 100644 index 83f52fc3ee7a9547161ce1284c2844976819c0ec..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_f16_f16_rank3_instances( - std::vector>& instances); -void add_device_softmax_f16_f16_rank4_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp index 046ff5780556331a6afc697f87dbbf02dcf1ff25..3fd2bd089ed8fdffa2d40c7b6b10ae146322ee6b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank3_reduce1_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp index 8e6a226f6a13fb1dcd7d61b56b476f38072b6a98..210fdc0a58548f8388602cb960ae0b3c019e05ba 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank3_reduce2_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp index 518fa5f98679099189f15dfe5a57e3ac0324e3fa..894fb034d0dfa86ec85d76b7f015662af1d775d1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank3_reduce3_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp index 10016cdd707cbcb46282e540e15e6b79a67c6a78..708ef0ce130e1095c4c88b520d1dbdbcabf23685 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank4_reduce1_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp index cdd5a3cd7b6fb6806db3de0a10b7ed87b29c5c2d..6754e5ceffa02d4bcad7919c7ea25b0dd85c0cc4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank4_reduce2_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp index a8be272e02014495cb42d0509b50075f2d7da707..5e111176e198d77b8e0e426f2a578cd9c3dd2b3f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank4_reduce3_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp index ec8296ff22fb6591152a3dcfad63c250c70cabb8..a3cecb32f83f707f54a53b7ed107f2662f4a2f08 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f16_f16_rank4_reduce4_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp index b3877c4bb3f6217427f69c07f7073751b254c938..8c0782daa556168ac359f68d301080720c622e1b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -16,7 +16,6 @@ template using device_softmax_f16_f16_instances = std::tuple< // clang-format off // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> - // fallback kernel DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8>, DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8>, @@ -33,6 +32,13 @@ using device_softmax_f16_f16_instances = std::tuple< // clang-format on >; +template +using device_softmax_f16_f16_generic_instance = std::tuple< + // clang-format off + DeviceSoftmaxImpl< F16, F32, F16, PassThrough, PassThrough, Rank, Reduce, 64, 8, 8, 1, 1, 1, 1, 1> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp deleted file mode 100644 index a6d9a359f4622ec9534ab090d46e363bd9101177..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_f32_f32_rank3_instances( - std::vector>& instances); -void add_device_softmax_f32_f32_rank4_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp index 6621a2c867ac90402a853e5c2065663a96220927..4cc469025335831b8c502e0693e92894ba454202 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank3_reduce1_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp index 3dfac98ed8beeee6060b6ba3f6b9b1ff19f660fe..65724d7888ac9801cc2dcd97ce56dd8681ca6cfe 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank3_reduce2_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp index 6d2a0c932500e874641d8cef9fd65b4d280c89b6..13bd45598eca17c81df5b8eba6fe41f3400057b5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank3_reduce3_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp index 97dd3dcb18aed5c228313a961994096445948bbf..d58b424ee94ffad33440525e190c7adf45598eb2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank4_reduce1_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp index 58f8760acccd1b571df40ea9a08946c1f5f9764f..378e45eeb783d85a6da303349c7bd7b0329f7fc1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank4_reduce2_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp index df8d31f0da76b0f027b4ee4818006dd331e38454..293df08c7e9e8edd9a419b7e94a5d83fa46efe8f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank4_reduce3_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp index 1bd773227e172278a466d7650e90fd39ec3e56f8..e503a9fec1f5c13e71b297d985c3de8c15497758 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,7 +14,7 @@ namespace device { namespace instance { void add_device_softmax_f32_f32_rank4_reduce4_instances( - std::vector>& instances); + std::vector>& instances); } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp index 16f129d2d07c19a926c75f1e878e8db3fb9bf821..90c5ddc8a01b7fd0ef58a8f744133a0635e2fe29 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -16,7 +16,7 @@ template using device_softmax_f32_f32_instances = std::tuple< // clang-format off // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> - DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, // fallback kernel + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4>, DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4>, DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4>, @@ -32,6 +32,13 @@ using device_softmax_f32_f32_instances = std::tuple< // clang-format on >; +template +using device_softmax_f32_f32_generic_instance = std::tuple< + // clang-format off + DeviceSoftmaxImpl< F32, F32, F32, PassThrough, PassThrough, Rank, Reduce, 64, 8, 8, 1, 1, 1, 1, 1> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp deleted file mode 100644 index f80f712ff5e501d6516d9c7210a59276a96b6541..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_instances( - std::vector>& instances); -void add_device_softmax_i8_i8_rank4_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp deleted file mode 100644 index 6f9952e7d58a448cfea0e3c6ba84af24e9de4e52..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_reduce1_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp deleted file mode 100644 index 2cbd13a1ba5226316363cfb07dfd23adfac4aa9f..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_reduce2_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp deleted file mode 100644 index 7b12522a85987f85e83a6f8afb34b97844507c49..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_reduce3_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp deleted file mode 100644 index 54d477f80c5b222cd5a76430dae3bbd4db81068c..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank4_reduce1_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp deleted file mode 100644 index 4ffc44e3a9217e689297bceb81bd6a38c41cd1c3..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank4_reduce2_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp deleted file mode 100644 index 08cbb81272f9ff5e7ce9a9fdd9075d81f1ed7211..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank4_reduce3_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp deleted file mode 100644 index 187d034b95ac506426da4fa1b8240b61e2517e46..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/device_softmax.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank4_reduce4_instances( - std::vector>& instances); - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp deleted file mode 100644 index 7fc9ed69198f1e7bf447a6baa9dedb3286440a94..0000000000000000000000000000000000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using device_softmax_i8_i8_instances = std::tuple< - // clang-format off - // InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> - // fallback kernel - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 16, 1, 1, 1>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 1, 16, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 4, 64, 1, 16, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 16, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 32, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 2, 128, 1, 64, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 16, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 32, 1, 16, 16>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 1, 256, 1, 64, 1, 16, 16>, - // Reduction on middle dimensions - // InSrcVectorDim is 0 since we want to coalesce reads on M dimension - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 8, 32, 8, 8, 0, 1, 1>, - DeviceSoftmaxImpl< I8, F32, I8, PassThrough, PassThrough, Rank, Reduce, 256, 32, 8, 32, 8, 0, 16, 8> - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp index 03be6e2bc7c15ee350569f7ed97ebf2d3ad3b515..10f99acb8d6fe4c5d1fb0a065231995a3d870392 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/softmax/device_softmax_instance.hpp @@ -1,8 +1,19 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp" +#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp" diff --git a/library/include/ck/library/utility/algorithm.hpp b/library/include/ck/library/utility/algorithm.hpp index 86f04dd362397e774c9821a16a7c412dea000937..57136f8a2a19fbdaa2e0276699c786c33cc61473 100644 --- a/library/include/ck/library/utility/algorithm.hpp +++ b/library/include/ck/library/utility/algorithm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index a89d03d324f33f0be95f88c328a9d7aee38e7460..8a7263137638ce52da10caaf37cd695007932765 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -65,7 +65,11 @@ check_err(const Range& out, } if(!res) { - std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } @@ -112,7 +116,11 @@ check_err(const Range& out, } if(!res) { - std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } @@ -158,7 +166,11 @@ check_err(const Range& out, } if(!res) { - std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } @@ -209,7 +221,11 @@ check_err(const Range& out, } if(!res) { - std::cerr << "max err: " << max_err << std::endl; + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; } return res; } diff --git a/library/include/ck/library/utility/conv_common.hpp b/library/include/ck/library/utility/conv_common.hpp index 6fad9f7d77d349853e63512ae486a57e7ceb7b6f..085454f42dee006cf7c48ebd18804df92fa36a37 100644 --- a/library/include/ck/library/utility/conv_common.hpp +++ b/library/include/ck/library/utility/conv_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp b/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp index 2b4f63b28b8d2f1e59cef67b99a577314faf4ae2..ff697fb71c0340df3f3aac6953ce76cfdef4b278 100644 --- a/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp +++ b/library/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/convolution_parameter.hpp b/library/include/ck/library/utility/convolution_parameter.hpp index f4a2b56f75a6b40a476080e74d386fab524286a4..df6efca10808beb96535ef69a3ba6116a449b9e3 100644 --- a/library/include/ck/library/utility/convolution_parameter.hpp +++ b/library/include/ck/library/utility/convolution_parameter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/device_memory.hpp b/library/include/ck/library/utility/device_memory.hpp index 87940e1671ab71b0e418c82fbec1f95d6627a8b5..d2e611a77cf6b7e1915de79eb454219dbd746043 100644 --- a/library/include/ck/library/utility/device_memory.hpp +++ b/library/include/ck/library/utility/device_memory.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,12 +20,15 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) */ struct DeviceMem { - DeviceMem() = delete; + DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {} DeviceMem(std::size_t mem_size); + void Realloc(std::size_t mem_size); void* GetDeviceBuffer() const; std::size_t GetBufferSize() const; void ToDevice(const void* p) const; + void ToDevice(const void* p, const std::size_t cpySize) const; void FromDevice(void* p) const; + void FromDevice(void* p, const std::size_t cpySize) const; void SetZero() const; template void SetValue(T x) const; diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index c0bc3727641041bbd1d9fb0c5e7456c5042b6448..4e075df43b0ffb0cd999ab332a75131e3d741e82 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -102,9 +102,10 @@ struct FillMonotonicSeq } template - auto operator()(ForwardRange&& range) const -> std::void_t()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); diff --git a/library/include/ck/library/utility/host_common_util.hpp b/library/include/ck/library/utility/host_common_util.hpp index 6f4466e8da0e2b77c1ad981cfb6f730a74a92d71..20a8f234dbbf9f30582ed371c5f1a90cce2b3196 100644 --- a/library/include/ck/library/utility/host_common_util.hpp +++ b/library/include/ck/library/utility/host_common_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/host_gemm.hpp b/library/include/ck/library/utility/host_gemm.hpp index 44036d0234375824ed3aa773ffe71de02b84c570..5eb7e3b8c9745c48390482622d361a2ddd491159 100644 --- a/library/include/ck/library/utility/host_gemm.hpp +++ b/library/include/ck/library/utility/host_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index 844c29ed14ccbc09df5f5f71e9ef2e99f6e8377b..816d8341308d27837f7b8da1b5af19e88f223665 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/span.hpp" +#include "ck/utility/type_convert.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/ranges.hpp" diff --git a/library/include/ck/library/utility/host_tensor_generator.hpp b/library/include/ck/library/utility/host_tensor_generator.hpp index 4259862e65e331b30377997c512459e2a21e9383..2fdb0b141dda33cddb67785bf3a4eaa07a1c0d4c 100644 --- a/library/include/ck/library/utility/host_tensor_generator.hpp +++ b/library/include/ck/library/utility/host_tensor_generator.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -95,6 +95,22 @@ struct GeneratorTensor_2 } }; +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f8_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; +#endif + template struct GeneratorTensor_3 { @@ -127,6 +143,25 @@ struct GeneratorTensor_3 } }; +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f8_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; +#endif + template struct GeneratorTensor_4 { diff --git a/library/include/ck/library/utility/iterator.hpp b/library/include/ck/library/utility/iterator.hpp index 9fdc88ea76825e2a1cdf659f95fa2928038867aa..b44e2d8e3c4a6f8b4423b03699539ade2fbfb0d7 100644 --- a/library/include/ck/library/utility/iterator.hpp +++ b/library/include/ck/library/utility/iterator.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/literals.hpp b/library/include/ck/library/utility/literals.hpp index a73a2ea054150665557b8e1da473f2cb9877526b..a8bd6303f12f0a36bafb36c1d7f1453eab13392b 100644 --- a/library/include/ck/library/utility/literals.hpp +++ b/library/include/ck/library/utility/literals.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/numeric.hpp b/library/include/ck/library/utility/numeric.hpp index 70a7e87ab1c8d94569e705b3f6832c78bac05853..9ee118d4757109416bac18cb54229e93de59b75c 100644 --- a/library/include/ck/library/utility/numeric.hpp +++ b/library/include/ck/library/utility/numeric.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/include/ck/library/utility/ranges.hpp b/library/include/ck/library/utility/ranges.hpp index 55c322f1ace041b9c4f346a5df4250c2323759dd..f11e4204a0cec5378328f69d3bb553a6227c47a3 100644 --- a/library/include/ck/library/utility/ranges.hpp +++ b/library/include/ck/library/utility/ranges.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index c206c4dc04007f942f056137aac78ad38d0c5cb1..1d54a141b7506a7756d349f5a37b34a6037eae0e 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -12,9 +12,46 @@ set(CK_DEVICE_INSTANCES) FOREACH(subdir_path ${dir_list}) set(target_dir) IF(IS_DIRECTORY "${subdir_path}") - get_filename_component(target_dir ${subdir_path} NAME) - add_subdirectory(${target_dir}) - list(APPEND CK_DEVICE_INSTANCES $) + set(cmake_instance) + file(READ "${subdir_path}/CMakeLists.txt" cmake_instance) + set(add_inst 0) + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8") + #message("fp8 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") + #message("fp16 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32") + #message("fp32 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64") + #message("fp64 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16") + #message("bf16 instance found!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8") + #message("int8 instance found!") + set(add_inst 1) + endif() + if(NOT "${cmake_instance}" MATCHES "DTYPES" OR NOT DEFINED DTYPES) + #message("instance should be built for all types!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS) + message("Found only dl instances, but DL_KERNELS is not set. Skipping.") + set(add_inst 0) + endif() + if(add_inst EQUAL 1) + get_filename_component(target_dir ${subdir_path} NAME) + add_subdirectory(${target_dir}) + list(APPEND CK_DEVICE_INSTANCES $) + endif() ENDIF() ENDFOREACH() diff --git a/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ec079e3ba3d4778254a97d53aab8a4053ceb7413 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt @@ -0,0 +1,11 @@ +set(DEVICE_AVGPOOL_BWD_INSTANCES) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp) +endif() +add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/avg_pool3d_bwd_ndhwc_instance_common.hpp b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/avg_pool3d_bwd_ndhwc_instance_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c989bbcd3df35233128d23d719dbb0fa70099731 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/avg_pool3d_bwd_ndhwc_instance_common.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I32 = int32_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using NDHWC = ck::tensor_layout::convolution::NDHWC; + +using device_avgpool_bwd_ndhwc_f16_instances = + // clang-format off + std::tuple < + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC + // clang-format on + >; + +using device_avgpool_bwd_ndhwc_bf16_instances = + // clang-format off + std::tuple < + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC + // clang-format on + >; + +using device_avgpool_bwd_ndhwc_f32_instances = + // clang-format off + std::tuple < + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC, + DeviceAvgPool3dBwd_NDHWC_NDHWC + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..52a8852f3039235c69ebe2d8c63603a271e2509a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "avg_pool3d_bwd_ndhwc_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_avgpool_bwd_ndhwc_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/device_avg_pool3d_bwd_ndhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/device_avg_pool3d_bwd_ndhwc_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..50de10e78eaa1dee44e6d4aadee132216625c479 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/device_avg_pool3d_bwd_ndhwc_f16_instance.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "avg_pool3d_bwd_ndhwc_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_avgpool_bwd_ndhwc_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/device_avg_pool3d_bwd_ndhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/device_avg_pool3d_bwd_ndhwc_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d4bb9a67f58a4b95d787fade0b376b211a43126 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/device_avg_pool3d_bwd_ndhwc_f32_instance.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "avg_pool3d_bwd_ndhwc_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_avgpool_bwd_ndhwc_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_avgpool_bwd_ndhwc_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt index 0f2a7391999db23c9df9578b687ee11fa2198b26..5b342595d888ee18f2f558da1aeefd8be2fe2f85 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -1,18 +1,26 @@ -add_instance_library(device_batched_gemm_instance - device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp - device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp - device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp - device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp - device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp - device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp - device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp - device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp - device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp - device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp - device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp - device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp - device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp - device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp - device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp - device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp -) +set(BATCHED_GEMM_INSTANCES) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp + device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp + device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp + device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp + device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp) +endif() +add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp index cc8787458d7ffff072cf8c055c8def43ca487e51..e730e9f589faf8047dd60a218ef191eea510cb63 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp index 04200cfb52e26f227d89852e2845547e6624c22d..f6696ffa9f6ca2c0665ecde3f0394c0802588a46 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp index 7b86f3cc728fdac1117cbe927445b6830e81e9ab..32d6f258b5c9152205f8b13e0326739b567b0845 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp index 2afb1afbc3d2355f682af9cad21b62c0ab1ee8ea..ee246ba56edd7271d7d2f980c62003a7ab7c7f06 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp index 68d768949c859f5ca8aedfebfa888253dc6c276c..5a9483b3099d9985f88f0121135a79d55edefd5a 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp index 737e5bfca3bccbafcd8179d80bb43b58bd545639..0fa0719237b000b7ab9b13d94cfa951edacf3139 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp index e09d01736083f54cf2c1b848814b2d7fb683d3b8..dc7de8c6893c51d465260252941030c2b3e76017 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -25,6 +25,16 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< // clang-format off @@ -100,6 +110,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( DeviceBatchedGemm>>& instances) { + add_device_operation_instances( + instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances{}); add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); } diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp index 984d66e2884273e006772530e6a978891766a8e0..cccad7ca1928cd824fa0547d125f95dae0b732b5 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -25,6 +25,16 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< // clang-format off @@ -88,6 +98,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( DeviceBatchedGemm>>& instances) { + add_device_operation_instances( + instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances{}); add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{}); } diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp index 12cada9c44d61d8b81ad2dfa4877dc57b26ed681..c687eb20f06918b41ca54fa72a2a577b809eb2cf 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp index 13f198862e5583b1f63c51efaab11fa733976cfe..b19374ca6d5635238da838808d863b19fd1b6b93 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp index 2ca1adc2f6d160a97a2f40b4075f07b185e68351..bbd318ba9299005910d63a83d6906f4b69571479 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp index fe5de52796f6c9a0c94a84d883a7b1a714c4a652..187ccb5ff62fab57a83139f340a4032fd03e8dea 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp index 5b55c8e15e01d879cc11e57128c3e2dfdf81890f..ec2b2646ffa1a78b058a90b0508eba7f1d65a2be 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp index 9517e4577e5b9af188aaf07413e8c0b3cf6c32ee..d76cd350c81807f330f7b04ae2ed2d1bf6bb5080 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp index 43b91244060ed90c9f2558dc99279ce0a086845b..ef65106c2cf537ae29dcc7313829ecaf3ae0ed58 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp index 326500fcbf013e8540fff9252202ac7af33b2938..078b241f9dfe194778806824911d791c7dd7a2de 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt index d0e9b265af541a7717e9d25a3aef96e4196298ba..6710035ec6ba0d3cf0b1c28a3561b4d0648ff9ba 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt @@ -1,4 +1,6 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_add_relu_gemm_add_instance device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp ) +endif() \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index e1bfa88f49d4fb886f953661c03b8638516ba646..e750f18cc7f54b60a5bc4303fe3da8042959cadc 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -31,22 +31,24 @@ using CDE1ElementOp = ck::tensor_operation::element_wise::Add; using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple< // clang-format off - //##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer| - //##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | + //##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer| + //##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcVectorDim| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + //generic + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // no padding - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>, // Padded fallback kernel - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8> + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp index f59b742534b2c234b2b4908f5e540f591b104619..37342e5f45b2465959e622f819e18ec88e744b67 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -31,23 +31,25 @@ using CDE1ElementOp = ck::tensor_operation::element_wise::Add; using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances = std::tuple< // clang-format off - //##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer| - //##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | + //##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer| + //##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcVectorDim| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | | + //generic + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // no padding - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>, // Padded fallback kernel - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8> + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 9, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, F32, ck::Tuple, F16, F32, F32, ck::Tuple, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 9, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt index cd9c95c066e6f14c4ee6904255e333db1792dba9..b0f37e68fd88822a98da3b50aabbfb7d71700d65 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt @@ -1,4 +1,5 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_bias_permute_instance device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp ) - +endif() diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp index 04a748f45513c56ce23781dd08bcf4c2bae03331..a0afaabbc7bbff785c7a3e6cba884a0bb1f768e4 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt index 865a31e79a5dfdeb3af54a7f59fc40dc4e90070b..cdb1a539016505ed9f4a92c80dcdf00badbe4855 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt @@ -1,4 +1,6 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_gemm_instance device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp ) +endif() diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 9b96194c874c3daea7c9c4eb931368ce5645b50e..67dfc4cd35066053c7b9936a76c402454cf1e99b 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp index 0713dfcd99ca183d8f64cd23ac6067be990b39e7..9001c901cfbecb473c6a0e4f70a9db6526e2a4cb 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..444c93b118894662d68f628dd9280308dbca7102 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt @@ -0,0 +1,25 @@ +# ONLY DL_KERNELS +if(DL_KERNELS) + set(BATCHED_GEMM_MULTID_INSTANCES) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp) + list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp) + endif() + add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES}) +endif() diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b20e8a86b1a2474a08a1186f65dfe93adafb44fd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e98fd24eeb759fb8bdde4d6982cdc219522e065d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..75042280f04c16982516c45d92ad46d40d98e91a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8e1acfd4f8a4a00f4e727657a5a7df327c57527 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f1ed7a69423a1abfb92b85368a921472db5fc5cf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 8, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e25517ed8600b8df740b14fa67b9ab9c60bdec3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a9fd16c83b9b7e44e718da96eb7612f9ac039be --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 4, 4, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<4, 2>, S<8, 2>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 16, 2, 4, 4, 1, S<2, 4>, S<2, 8>, S<8, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 2, 2>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<1, 4>, S<1, 4>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf5502926b772e670ff8a2136eeafc17e548da8e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, F16, F16, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..43cdcc5f140fa792181893b4bc2e586c9350e087 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 4, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 4, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 4, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 4, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 4, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..617877499d1c97bf99535f13ff057946487032af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f58aecf5084b0993fc6f8e75e7b8a406cad9382 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e0ead0ec4940482bbd19a68affd9064d09d9cd83 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55fa86f37ff58b82a881cce16750cb78c19e828d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b8af3e61a3ae6ce06f50ed475d4f40162839e851 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29d20e58ddbd3ad646d107dec278a4821936b35d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=16, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, + // MPerBlock=64, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 16, 16, 2, 4, 1, 1, S<2, 4>, S<2, 4>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dd0f8749c00dceeabc834787dfb7c907c3ede909 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Empty_Tuple = ck::Tuple<>; + +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances = std::tuple< + // clang-format off + // ##########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| AccData| DsData| CData| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##########################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##########################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // ##########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // MPerBlock=128, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // // MPerBlock=128, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // // MPerBlock=64, NPerBlock=128 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=32, NPerBlock=32 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=16, NPerBlock=16 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=64 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=64, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + // MPerBlock=8, NPerBlock=8 + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceBatchedGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index db3719cff8a16a7f8dd2a42cbb8da1009555070a..728e35fc3be2ab845d30f9576b01e12aa6d57948 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,7 +1,8 @@ +if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_reduce_instance device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp ) - +endif() diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp index 521c3d9219b1cd23e30ca194d3e22d1f38fa9316..cb89d3cefd6525d017d7f7b4685be456cb1c7e44 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp index 231d612d781bc93c2a232f7e1478fddfb09c1b5d..91eefba0c10af87e7674cb053ee549a2799de7a9 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp index 165bc3957d36479b1cb3052c0085ade58fb71322..c20798f557e58e870f748d7900f1418f3a1a5779 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp index 832fc3b066066f9ac8faae30cebff714b73b53c9..3d9ad64b9b40a2271ed40608735bffb0133153f9 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt index 29fce56610929805ea706dbbea6cdfe18fdcc44d..5ac55655d4d6ad3fec6f34e68098459073b2f5bd 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,4 +1,5 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_softmax_gemm_instance device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp ) - +endif() diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 99e8712474996a6c30e8bf16ae69f5bc15925364..cf23d01bf234e2cdda3efb2b655879a4740f5c27 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt index eba248e59978c867d53d64a6e43e9f2187a2e7d5..d5110e45057aea85585a25a2bf65a493e94be49e 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,7 +1,11 @@ -add_instance_library(device_batched_gemm_softmax_gemm_permute_instance - device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp - device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp - device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp - device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp -) +set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp) + list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp) + list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp) +endif() +add_instance_library(device_batched_gemm_softmax_gemm_permute_instance ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp index f73e3dea84ebe2f4e70af03df967bcd3b2bd7450..498bf58fb3e6cfaf444fc92923ff261389ac0ac8 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -41,10 +41,11 @@ template , ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, @@ -58,8 +59,9 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_ DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index fd094719584147781d4246a508ada7104cdf12e5..744bd6456d28f719a28508f65bdbd78bb02d4120 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -41,10 +41,11 @@ template , ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, #if CK_WORKAROUND_SWDEV_388832 @@ -60,6 +61,7 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16 DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec, 1>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp index 53ad7ba5ffa9272da26189a45a75a9453af68f48..b342612d1c098070108696b7a76a8e7fe754fb51 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -45,6 +45,7 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_ // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, @@ -58,8 +59,7 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_ DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 291a127a6603bd5fe1256076355df76d133978e9..3fd0c07370d40407dd3ef754f55ccff9deb091df 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -45,6 +45,7 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_ // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, #if CK_WORKAROUND_SWDEV_388832 @@ -60,8 +61,7 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_ DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, // Padded fallback kernel - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, - DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp index b62c8b99cbdb6cf42922e843c11a0832614b6ea7..bc95d2f1b1f1ad3734daaa19ed5dbb1faff5a6a4 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp index d05b8b592c29aceef387a4fdec8c02854501995f..fbc8d0bc6097a477fe9920201167a58142aa9b44 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp index e3ef95d12e17acfebc029ac6e1a878194691739f..bed38658a99c109370e6271dc94f785d94268106 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp index 41be396c24a28b5a7fbc47b10d7cf6ae5e6e07e9..fc5ec77e428dfbfb09a902245f27ca25b8bc0d35 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_backward_f64_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp index cd1e05b1133ad2a2515a936ef968667a03bc11be..4e38ee13b24dd150f1366ed7694e69c6969a2e2c 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp index 073dd583f97e0eed167e4bc4ef952530e82854a2..f087eb79824e3e36c6bc07a3d42b68f38a89c16b 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp index be63bd44c66c3204c66990ff6048bd895841c84f..d0f361401a500c16f1228d4df1d3de6714d1f43a 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp index fe87091e8d66e32af0832c1b013abce0911673b2..710d07b8280cef467d2fac8a70034129feb13fda 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_bf16_instance.cpp index 2e695afa9786d2e4fdcc6d9bf05c600c1a333181..8801c309f9138ccf321ba1dcac5f1718287e0a76 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/utility/tuple.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f16_instance.cpp index 9ec761e445ac30cb482c35be42846c3e46a13ea7..b674cfc423d81b5efce31169e1c81741ba156b9a 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/utility/tuple.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f32_instance.cpp index f0d26c36bedfa037a439e6226e8b1db4051d0fc3..05e3650887c3dfa84d1f03a8380d1144c7886acf 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/utility/tuple.hpp" diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f64_instance.cpp b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f64_instance.cpp index 9e4066bb060e0a58342956860356ee589bd5e9c0..15a02af021a22b4e9e43186a7b3e18b938534dd8 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f64_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_infer_f64_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/utility/tuple.hpp" diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index d2a0a3d0fbe7a869b2d1603fe8879a201c1ca4b5..1db6985f61606031756f6b056cad92eda32b1802 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,13 +1,17 @@ -add_instance_library(device_contraction_bilinear_instance +set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) #float - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp) +endif() +if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) #double - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp -) + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp) +endif() +add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp index ebbff88346b9dcb1eb4716b1b36f456b806ef1ca..5587db77e08ceb40242084853ed07236cc992605 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp index 980383f3e71b3f1058c295c9aba55c56f596134a..26262855ea8b6e3ba2c929e527e97dc5108495dd 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp index 2d4b6e3489a9ccce619b0a76dc097f04d72b0d57..befc0dcd10d07735ec046e8ef4d64bb54be4756c 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp index 7caa469f54b0169449250f000df038820e35a32f..e45b47cf94ed53de17b862e50ed2ef7a00e3e978 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp index 093b2f0e98bb5e12d867bcdd58c4817e2cb33111..f437a227d50b665b999e30b2d065ac6825e7ec04 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp index 0f683e5c280306b8c14b461a4d46183055094afa..13fdbeb35c53a93947f739dd26f91885cff4a72b 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp index e384993aed76e58f11238cd2d5bad3bb4515e457..95ef8c4929b00263c65034ac1ec48195efdc8fff 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp index 92e39c173f55cb30d730415d44eda4eb4c1c3b32..290f81d7c9a926dd4e9c169e3515ebdbaa547788 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index 31f6a0fcdc9a04d15d63ffc93e2dac550a0ed293..aed8bef2a867a009ac69503be5bfa5a4a8f45ccd 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,13 +1,17 @@ -add_instance_library(device_contraction_scale_instance +set(DEVICE_CONTRACTION_SCALE_INSTANCES) +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) #float - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp) +endif() +if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) #double - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp -) + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp) +endif() +add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp index 5118d0d033ee733c9b3bd7bc3524824c77aae667..16fd1cb407bf81741845c41baa580f208e98508d 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp index 655d4f0061ae1c417e610e9059f7e2bbd2048363..ff37bf7cceb40f75c8a54309341914539057d678 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp index a9d20be18bf5c52044b12f84c58f7824f8ef8142..8a1f6f93341edbb585157cd66308ab8f781bfedc 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp index a68f5c9718a1e547c62f16888a4fd0a0ed1d8197..d333f597268d55a195da671fcbcdaab2c03acfc0 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp index 0aa927155a412fede87f08fde619884f3aaa3405..4c87b51a9e426945f049d471bf3591d1f222b72b 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp index b84ea274c5022c173d3a396661bf78396b07619f..fd3f57c6b6058f16e9e34158fefc71e3fb9c5f3c 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp index 578469997acf056af0ade6a6d92872aad9d81b3c..1e53f0b2f51f23143b1d49d1b7d9b2d3e2b7b765 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp index 8e5a19313ee4262668dc4481139b6284ca898538..d02d146a9bf98b25fe7de56148c8879dd62a328b 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp index 5a5c8384227150dd8717433dce6c9ea2395896bd..e3e90c966dd6efd920ea0147d434f926bd07bd05 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp index e0f3d6199f8939bb0f1ce7f8efc48f2a92cfd3c6..81e9122d9506eb3ae78602c7beedbd2ce1b0b745 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp index 30537d9373b8175207b267666574033cae4e1a3c..dbc82168f49ad102886ce8aebaea72cd3ef5a99a 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp index 190c39b870b3cb33e9080c8b591884f4ec387347..3ac250f3e6a9d3fe8b3139aef5a79dc520269381 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt index 281453b586ba9b635063a3b51642066450c253a5..54ef9cc7a6bd929ce6f3012acbf5821e2d1e318f 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -1,10 +1,23 @@ -add_instance_library(device_conv2d_bwd_data_instance - device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp - device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp - device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp - device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp - - device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp - device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp - device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp -) +set(CONV2D_BWD_DATA_INSTANCES) +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp) + if(DL_KERNELS) + list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp) + endif() +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp) + if(DL_KERNELS) + list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp) + endif() +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp) + if(DL_KERNELS) + list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp) + endif() +endif() +add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp index e14cd558628772991da9b9d4ebb3dd5d852e3285..6ca909c35edd59e55f9d8e4056db5bad3bbbd140 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -9,7 +9,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef DL_KERNELS namespace ck { namespace tensor_operation { namespace device { @@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp index f001b83c1711a7185d71f5b2c7f795005b35b696..d263e98851bed29a62f6aa8a8220f5015976e0c2 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -9,7 +9,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef DL_KERNELS namespace ck { namespace tensor_operation { namespace device { @@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp index 83ba6a1c6bb8a3098da1b31a59a1d4a12b8b062e..bc949e757c0b0815cffb64aba59f130f4cc537fc 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -9,7 +9,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef DL_KERNELS namespace ck { namespace tensor_operation { namespace device { @@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index 1da9a81d904b3e43aaf5ace4f612c43c05e66ee0..366d1fe1602596d9033ccad392ac216299326110 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_BF16 namespace ck { namespace tensor_operation { namespace device { @@ -155,3 +155,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index 7c33df5e7687ee3e279b350bec9ca17902e84231..422e37e926fd7a839afc2cf884b141ca32737078 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index a5f8629f2daef05ca1dbcbec2e8c4c668bf96a84..5993f6bd7a28c2992bf2e069880e84bf75612dfd 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index 8076d6d35669114d84b922a59948b888dcdab506..36610ae2054771810eda2d24430fcb541d19623e 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -151,3 +151,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index 5b646852fc5c6c71d963000681b335f909a10286..96ecc9565820d937f80fd06684c4ef70cc9ab69a 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -1,7 +1,16 @@ -add_instance_library(device_conv2d_fwd_instance - device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp - device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp - device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp - device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp - device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp -) +set(DEVICE_CONV2D_FWD_INSTANCES) +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp) + list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp) +endif() + +add_instance_library(device_conv2d_fwd_instance ${DEVICE_CONV2D_FWD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp index 33503b9f8ae00a92c25b178629736c7e8a5474c3..da96c79a6e9a44edd05928b54a865483f6a1e1a2 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index c5e4bd199ea8322312af832b63a0ea559109131a..63c612523fbb53c2a4c3b676dc0bfe102155c8b5 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_BF16 namespace ck { namespace tensor_operation { namespace device { @@ -126,3 +126,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index f43d13e309333bc86abf22465b7f035c1445321b..0f3b9e7939aabfb70985e2bc7d5339b15df8eb2c 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_FP16 namespace ck { namespace tensor_operation { namespace device { @@ -118,3 +118,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index 0ce6b04c4218190ccd3cad238f1bc90ec3002df9..14f9b5cd6ae5f9ff25c98df97bc310df6806f06e 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_FP32 namespace ck { namespace tensor_operation { namespace device { @@ -117,3 +117,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index 76ab3189d7f210446b305a3daec80a2902de63ff..3f641cdadc69c31cb646d52c37c61ea2449bcaaa 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -123,3 +123,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp index f8c255088877e95e2cbbbacbfba6f9e1f4fc0291..e34ea06ff46ba090a212575e66dcfd454a44de66 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp index fe7152471ed54aa5334e820f993232912013c6de..3254fcfc26467edc3ec6b14d1838e784d7438529 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp index 04ce7c07639c968d90a8053f4b67b1335acc9b69..94b2a47e50b635e5b862c109f479660183d2bede 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp index 0251d9157f3b894208b17091ab6c79eec8bf605d..4244ab7b8770c9cbd60957c4fd66f004b7f5975c 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp index c2975727e48385207678abdcbc1fda776e761664..5c7db4ca3bc4dce89cd1d139722d1e91758f8d7a 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp index fc86d7302450a01263dc07d8f895c669795f3644..ebc56487a159afe8b32e5e321c6996c38250c32f 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp b/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp index 182037f15c66c2d196e86b84ca0f0571be2922e9..f2a5f0728ac1d8bc4bff78aa5476e872a7085163 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/elementwise/device_normalize_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -30,7 +30,12 @@ using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std: //###################|| | functor| NDim| MPerThread| | | DeviceElementwiseImpl, Tuple, Normalize, 2, 8, Sequence<8, 1, 1, 8, 8>, Sequence<8> >, DeviceElementwiseImpl, Tuple, Normalize, 2, 4, Sequence<4, 1, 1, 4, 4>, Sequence<4> >, - DeviceElementwiseImpl, Tuple, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> >, + DeviceElementwiseImpl, Tuple, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> > + // clang-format on + >; + +using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_generic_instance = std::tuple< + // clang-format off DeviceElementwiseImpl, Tuple, Normalize, 2, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> > // clang-format on >; @@ -39,6 +44,9 @@ void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances( std::vector, Tuple, Normalize, 2>>& instances) { + add_device_operation_instances( + instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_generic_instance{}); + add_device_operation_instances( instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances{}); } diff --git a/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt index 0c7cc2cd31288b2f2aaec0dc49f92165b912c225..ea53b82be57a25cb90977a0a097ad49ec1be4127 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt @@ -1,3 +1,5 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_elementwise_normalization_instance device_elementwise_normalization_f16_instance.cpp ) +endif() diff --git a/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp index b160d4fe1a3a271113899b2a8bc57e5535706280..3e2386ee03558ec6d346abc73a037e65e2d9cde3 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index d66010af734581ae204015e7f3d663040d0b111e..48dd292f05ffaa4517659ff97819cf725631cc30 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -1,51 +1,137 @@ -add_instance_library(device_gemm_instance - device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp - device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp - device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp - device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp - device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp - device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp - device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp - device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp - device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp - device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp - device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp - device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp - device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp - device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp - device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp - device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp - device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp - device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp - device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp - device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp - device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp - device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp - device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp - device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp - device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp - device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp - device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp - device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp - device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp - device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp - device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp - device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp - device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp - device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp - device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp - device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp - device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp - device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp - device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp - device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp -) +set(GEMM_INSTANCES) +if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp) + if(DL_KERNELS) + list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp) + endif() +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + if(DL_KERNELS) + list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp) + endif() + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + if(DL_KERNELS) + list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp) + endif() + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp) +endif() + +add_instance_library(device_gemm_instance ${GEMM_INSTANCES}) + + +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + set(ENABLE_PIPELINE_V2_OPT OFF) + + if (ENABLE_PIPELINE_V2_OPT) + set(MAX_ILP_OPTS + -mllvm + -amdgpu-enable-max-ilp-scheduling-strategy + ) + set(WAVES_PER_EU_DEFS + CK_USE_WAVES_PER_EU=1 + CK_MIN_WAVES_PER_EU=1 + CK_MAX_WAVES_PER_EU=1 + ) + set(IGLP_OPT_DEFS + CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT=1 + ) + + # layout=NT + set_source_files_properties(device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES + COMPILE_OPTIONS ";;" + COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}") + # layout=NN + set_source_files_properties(device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES + COMPILE_OPTIONS "${MAX_ILP_OPTS}" + COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}") + # layout=TT + set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES + COMPILE_OPTIONS ";;" + COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}") + # layout=TN + set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES + COMPILE_OPTIONS "${MAX_ILP_OPTS}" + COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}") + endif(ENABLE_PIPELINE_V2_OPT) +endif(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp index 35df85b7afe115b2048c8cb5a586af561e79412e..ea99a5a30ef80bdb954df78535a6cd63b8663bbb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp index 7d0863c953788d35f1a4730323ba0700c5195ab0..b83acfa8cb4db572bb33ccd20e6d38e7594f0234 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp index c1b11b19df4470f2cd02afc3483e8ce93ee05451..d5800e03336e9161e01ef240a46aefbf9b13c2c0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp index 9a889e8d551e1dca2a9c7920bf0f4db93aba3507..abe52ce1da0edb191755d205a4e744c274ae56b1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp index d8e7798438dcaaa9d7c2e48e8074a2f58a635438..e696bfdcdc07f6799a00b9463837031559b715a5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp index 0034ac59c3888d8dcc1b2bf8e136ac4f02c2e6ad..d3ad7c60ec6f8cf2321aaa03e18e62ac28e65773 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp index 0b540b8b3492c355d75adf98cacc1826417db0b9..a56a36b0ab9cf0260a0366560d8816f18f1f5d7e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp index 4f6ff5111b1fa5dbf39e4c60d1770a319aa9dbba..63d55e81d6a5cb1157ea72a97e01599e6fc750ff 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp index 9f5cebcab0e5d11f76efb741e7e33ebec18aef5b..05f399e218fe40f8737c5b04203a61cf2ed1acc8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp index 0439201511f285a0d7eba405df72d1942a4fb49b..d1eb8edf979735ef3801bdb4dfc77db266d05fc4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp index 41afb519f17b036730c975b2ff0ca6b3fef96bce..8df3bb9641055c414943573b127ebc1a39c35cde 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp index 350834f7e515b93645e2f22285c19f9b9285304c..3de5458e05ff149f04f9019e3daabfc68a85d57d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp index d1173095fc1fe9dea88fa1878e83cb291568cac2..8969983356191df9abaad43dd096842f96dca53c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp index 27397527bc973c56bf15a8c35e86dec89184959c..745a4bb31da6d30008469daa4e257fb6e95aca1b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp index efda345a8d42446ef5f9ceb7992a07cd9338e083..2bda30f82cb28aba19db635f88f28a4ebf456cb4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp index b99f3f2b63902c4b832c86dcdc3f74376d972dbe..b5e8b8ecd66037ac489a0e44a7cddb1a9ac80b0f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80da6d4c3b43e5c48ab8758981aae9773dc25b9f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_km_kn_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 4, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 4, 32, 8, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 4, 4, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 4, 4, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 4, 4, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 4, 4, 32, 8, 1, 4, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 4, 4, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..97b85fd1a0d6a59fcdd50133a5d12ff9ac3f62b8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_km_nk_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 8, 32, 8, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 4, 8, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 4, 8, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 4, 8, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 4, 8, 32, 8, 1, 4, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..370ebfccc4cd428a38a022604b348f6c84a1b2ef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 4, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 4, 32, 8, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 8, 4, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 8, 4, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 8, 4, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 8, 4, 32, 8, 1, 4, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6053275492ecd2fadddb53cc181e899ce3381a6e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 8, 32, 8, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 8, 8, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 8, 8, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 8, 8, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 8, 8, 32, 8, 1, 4, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp index 03eebf4ec3062cf5f85486288673fb96fb2761b1..2e884dfc8afba8d6c6ea88a8e3159c7a62660364 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp index 5d8de04cd9fa66931dde4ed0494fc429659506c3..2ca29b1e6f0791154d16ec1b5dd134d8d9c77c37 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp index 7b12b7cf1f42da6d8c321c1ac121166753b7ff59..706076098da912ea9e5a844ccb2797cbc850a01b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp index 730ffd4633de3c72cb6ab932126a18d8798418dd..5ac458a7b985d16faf90a53866ec4f9d2bc9c6b0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp index 619473ff0ce40a5150c59efdfd62bc905d7ce83f..a64412544721237542c729d0f9581acb39917555 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp index 8e06f9d26b4d3cf8a22259a54e103c7b8ce76190..44b684823391fd83caaf4b55a1ef41c19fe5b3c1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp index f9458b7483c5c29011b4aa063061436da23f9e36..23176269c2c84808ad05a874259694f22c6064d1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp index 77a03b746cc12e5f8cea8772c4432c9fb4e24110..31a9abe53a2312c25ebeb5157e81998054d2345b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index ef8d7d4e40ef4de46857db4932e109780e52f044..201fd93110b604b1c6df849fcc5c88ff98a93bea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp index cb65cc7b68b359ec015af3adcd3a7866983d492b..5d489b207efdf62de64d99469acfdbe49a13859f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp index 5b1014ed827d08f275b5da4c20abe14f153b9950..e09480d57bc20e54697fcd3c5ae761a66fda678b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp index e6f6add8bf45e98f6e9607b48a924e00ef642323..34065c334d1bee6064718a8dbac6f0b6f51c5826 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp index 80b3d03da26c49136c8220f8a433d8798ad751e1..95d7777a79371c4d9dfa1ec6ff09b16dc55a3807 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp index 93b3df1e572604e70d059c6cb622f8143eb2a92a..d02fb8f70beef7a3cb6359277a0136ecf28d9c75 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp index f10365d892455e3e36f5e05ec94b7a4f577c1fbc..abf79262e6ac85dc6828616f90851fd048cf5b52 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp index a7a9eb62cef20c05cee9ff76e5d1c0b21fb95ab9..5da89c3421cf4da725f53b0d41619e077668ed6d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 9fb45b00365896be88e6aff214922c51325a4d65..0ade7a61cedc817a5afe79cbdc315aa9caf324e0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -8,7 +8,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - +#ifdef CK_ENABLE_INT8 namespace ck { namespace tensor_operation { namespace device { @@ -63,3 +63,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( } // namespace device } // namespace tensor_operation } // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/common.hpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..de96a22630c88ccfe9a4b1a2f84b01e0d4751613 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/common.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using InstanceNT = DeviceGemm; +using InstanceNN = DeviceGemm; +using InstanceTT = DeviceGemm; +using InstanceTN = DeviceGemm; + +template +using OwnerList = std::vector>; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d887b16e6015ae2bdf0b4a157e6c5f8bef75b5d7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Instance = InstanceNT; +using Instances = OwnerList; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_interwave_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_opt_instances(Instances&); + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_default_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_interwave_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_default_pipeline_v2_instances(Instances&); + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(Instances& instances) +{ + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_interwave_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_opt_instances(instances); + + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_default_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_interwave_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_default_pipeline_v2_instances(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..81fedd50f0b2c790fe93de9891dc0f952d9cb798 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using Instances = + std::tuple< + // clang-format off + // pipeline v1, 1 wave + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a0c52c2df9efde9d755420ffa34bf21bd1522f1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using Instances = + std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..59ffb80bd4b6fcbbd7850ec017c526f41a87f87c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using Instances = + std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_opt_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a64424e8acb300fe51d2c44c0ebd02a92e731c6d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using Instances = + std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_interwave_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e313a790e10b0f0c94c491e0409d630cb6441938 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< + // clang-format off + // pipeline v1, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_default_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..050ff6e231eede6fc5b83fa4a779007f44634115 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_default_pipeline_v2_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a0406baec7c7fe883ff1877132fff16a07ad3cd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_interwave_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7fb1a4845db11028c35639cd0e78fbb387e195ea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Instance = InstanceNN; +using Instances = OwnerList; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_interwave_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_opt_instances(Instances&); + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_default_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_interwave_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_default_pipeline_v2_instances(Instances&); + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(Instances& instances) +{ + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_interwave_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_opt_instances(instances); + + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_default_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_interwave_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_default_pipeline_v2_instances(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0dd60c0f56a7d985d9e65e5419753aefba775c6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using Instances = + std::tuple< + // clang-format off + // pipeline v1, 1 wave + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..122fff4960306d94839161f4c12b0828f95446ea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using Instances = + std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f459aabfc63ee4e71de5a7d47fdcaa0894937f4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using Instances = + std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_opt_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3671bea7a3fc8fd8240976d616bb63b34fd7f642 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using Instances = + std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_interwave_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7e7a0550515134f5a0f5f739ad731183cb6aa4b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< + // clang-format off + // pipeline v1, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_default_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15e191e5a81c329eb98d698a9dd118f54459d61f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_default_pipeline_v2_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95fc8ecb46be3107cf47566526c67e2ddf4a35c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_interwave_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3a7ccefb4d73fadad67f0e93b58f93ff3ead54f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Instance = InstanceTT; +using Instances = OwnerList; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_interwave_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_opt_instances(Instances&); + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_default_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_interwave_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_default_pipeline_v2_instances(Instances&); + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(Instances& instances) +{ + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_interwave_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_opt_instances(instances); + + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_default_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_interwave_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_default_pipeline_v2_instances(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98db8bad1ca01f32c78786e778ba302002259b1f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using Instances = + std::tuple< + // clang-format off + // pipeline v1, 1 wave + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..532c348b7e1571bb7d64b33ac29ab4d270dc1f22 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using Instances = + std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b931b8fdfdd41716b32ec40926576bb3a41ae9f0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using Instances = + std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_opt_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa53a3bf0f689376c9752c0cb32915408acb71d8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using Instances = + std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_interwave_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bc4b3e801bc52a64122f0cf14998b7a6ae2c047b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< + // clang-format off + // pipeline v1, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_default_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0c8e8b2c416f724166184889b794283155f1ee59 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_default_pipeline_v2_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c9d1913aec737f581990403b5b4ec5768c718978 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_interwave_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..70bfd0a80cb7f7ce9335396e2fbbdf6613f6b67d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Instance = InstanceTN; +using Instances = OwnerList; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_default_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_interwave_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_default_pipeline_v2_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_default_pipeline_v2_opt_instances(Instances&); + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_default_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_interwave_pipeline_v1_instances(Instances&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_default_pipeline_v2_instances(Instances&); + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(Instances& instances) +{ + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_default_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_interwave_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_default_pipeline_v2_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_default_pipeline_v2_opt_instances(instances); + + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_default_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_interwave_pipeline_v1_instances(instances); + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_default_pipeline_v2_instances(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0f4ccceb50a187daf14ec85a0ba5e59e877b0538 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using Instances = std::tuple< + // clang-format off + // pipeline v1, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_default_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..134755299d58073e28a1320ca26a08a2dd981c3f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_default_pipeline_v2_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76d71ee97a4e03138d757b24a7f593b4b2e7a945 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_default_pipeline_v2_opt_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0410eabb700eeed4ae642d2a8f788df2b0ea9467 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_interwave_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e6e73cdccaf9095aac0b5e88915416df41a9f1bb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< + // clang-format off + // pipeline v1, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_default_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..78f18fd5821450a5cf701feb7ebaaca0d443579d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_default_pipeline_v2_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a41919aab7c84d9a2436445869baf47c3ec5c531 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// irregular tile size +using Instances = std::tuple< +// clang-format off +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| + //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | + //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | + //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif + // clang-format on + >; + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_interwave_pipeline_v1_instances( + OwnerList& instances) +{ + add_device_operation_instances(instances, Instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp deleted file mode 100644 index 18a78674e7ac4d79e6da3b035c8e9ea966618f16..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances = - std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> -#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES - // pipeline v1, 2 waves - , - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> -#endif -#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES - // pipeline v2, 1 wave - , - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> -#endif - // clang-format on - >; - -// irregular tile size -using device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple< - // clang-format off - //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| - //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | - //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | - //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> -#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES - // pipeline v1, 2 waves - , - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> -#endif -#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES - // pipeline v2, 1 wave - , - DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> -#endif - // clang-format on - >; - -void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); - add_device_operation_instances(instances, - device_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp deleted file mode 100644 index cef6070af8a72a6cdb0616574736e722094f3a74..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; - -// Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances = - std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> -#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES - // pipeline v1, 2 waves - , - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> -#endif -#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES - // pipeline v2, 1 wave - , - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> -#endif - // clang-format on - >; - -// irregular tile size -using device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple< - // clang-format off - //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| - //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | - //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | - //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> -#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES - // pipeline v1, 2 waves - , - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> -#endif -#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES - // pipeline v2, 1 wave - , - DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> -#endif - // clang-format on - >; - -void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); - add_device_operation_instances(instances, - device_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp deleted file mode 100644 index 1be70d6ca414f785ec2582fd90bc5d3d2617e4e3..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,137 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = - std::tuple< - // clang-format off - //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| - //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | - //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | - //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> -#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES - // pipeline v1, 2 waves - , - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> -#endif -#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES - // pipeline v2, 1 wave - , - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> -#endif - // clang-format on - >; - -// irregular tile size -using device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< - // clang-format off - //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| - //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | - //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | - //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> -#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES - // pipeline v1, 2 waves - , - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> -#endif -#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES - // pipeline v2, 1 wave - , - DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> -#endif - // clang-format on - >; - -void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); - add_device_operation_instances(instances, - device_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp deleted file mode 100644 index 6b8455ffa93fcceaf9d0e5d5595197f8bafddcbe..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,130 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; - -// Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| - //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | - //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | - //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> -#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES - // pipeline v1, 2 waves - , - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> -#endif -#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES - // pipeline v2, 1 wave - , - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> -#endif - // clang-format on - >; - -// irregular tile size -using device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< - // clang-format off - //###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| - //###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | - //###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| | | | - //###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> -#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES - // pipeline v1, 2 waves - , - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> -#endif -#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES - // pipeline v2, 1 wave - , - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> -#endif - // clang-format on - >; - -void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); - add_device_operation_instances(instances, - device_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp index b9e28e3d7ef70b2e3aa785b048280fb4959ce312..a590413accba4e9e0dddd8df2952cd9d3fdb8b35 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp index 2b1a5a57bb8b22b9235fdd02fe3f9da659b13c0d..1d010d1b0792dfc62f0a1de235558ab8e4e71fd0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp index 301d3b55b59c0a6818b27570fe1c9a7b9ff44946..f108b753421598c376eefb8cfa52b0a4f1522920 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp index cd16f35ff233fb04c79d50888f54b806e79e611a..b0b4bc012df347cf3e487d96ecef2e03826f5e77 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp index 39166698473dbd05010f960e82366012ffaa1863..df3bd94fcafa5062b597584c3aadf2772c376f39 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp index 0a623034ef65a1cfd0d220b9ad2253ab0b298620..73b4e776668c969fb6ddcd38f74bd1a74fea48c8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp index 5ef8d08de90a65ae04f552fc87447e3b7887bc92..76137a1c3e1e2cb82c9e15ec9bc71e4abe76391e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp index c9557bae893e1affc0f2138aaa9cfca206e267d1..f0158d8f3d51fbb773988d02c14c329e9243ebf9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt index bbf81a5fa25d096a87c9ece48f117e38ac70550a..9028829fe9f32a73fe636861e7913c6177a1cc43 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -1,6 +1,8 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_add_add_fastgelu_instance device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp ) +endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp index 463e0865c0ae6980c534005d1ace62750645a74c..125fbc21a2b9c141adc8d11b0727667ca062fada 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -36,6 +36,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[k, m], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances = std::tuple< // clang-format off @@ -139,6 +150,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn PassThrough, AddAddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp index b71ff1b9972a97276a0fedf0106814a54ba19a0c..cc33692d7378178903fe2b0a9a43180571d70276 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -36,6 +36,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[k, m], b[n, k], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances = std::tuple< // clang-format off @@ -139,6 +150,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn PassThrough, AddAddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp index 9060c9b1b084512f038ccd5c5b0c3573ac32f0a9..704787a080ee0af7ea74f9f75a290a69cf73adf0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -36,6 +36,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = std::tuple< // clang-format off @@ -139,6 +150,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn PassThrough, AddAddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp index 81cf01d6a9d70346e8ea2387220fb03258e6e580..d64c9ec5e15aa85c54d5c852755be63ca0597327 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -36,6 +36,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[n, k], d0[m, n], d1[m ,n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = std::tuple< // clang-format off @@ -130,6 +141,9 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn PassThrough, AddAddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 0beb10e37978cf64d51a9463ca9c01c278a51273..1085966807a50afa5ab23fc5a3e4d234bea0f286 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,6 +1,8 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_add_fastgelu_instance device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp ) +endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp index 4da85cc46eb4ded15bcd63dd7b0d89864637d2b0..e68bd8e7e4aa032dc56007e22aba0f90ca0c6af8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0) // outout: e[m, n] // input: a[k, m], b[k, n], d0[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -123,6 +134,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_inst PassThrough, AddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp index ab83e4baabf80242bfbf569fb13b5d39a3686ad7..5aaa2e8fe509a18ac360859f0cdcf67535b75fd8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[k, m], b[n, k], d0[m, n], d1[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -123,6 +134,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_inst PassThrough, AddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index a4cd3fadbe93f4231388e9efccbe28a2c69a6f1b..7a2a3dbaf3d82cbed71e91e2bebe8fb3f3dd4ea0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -123,6 +134,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_inst PassThrough, AddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp index 207e76ffe5f18f9a4d64a0083c27829933e7a5fa..fa3360997886de27f75b89457d8e1ba9ae4471d1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,17 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b), d0, d1) // outout: e[m, n] // input: a[m, k], b[n, k], d0[m, n], d1[m ,n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_generic_instance = + std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -114,6 +125,9 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_inst PassThrough, AddFastGelu>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp index 3f30937ff5b3b7235efc93dc6287a494ed783dee..d9ae7817914dbcb7dcae38254afef6f02e54a4da 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp index d91e6c63bea06d70513106cf7eac171b892498b9..27c6cbe8ab71398f9c3bc03bb977c71af397836d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp index 143321542353e4d84d7275f375d9fd49103bc8cb..ff5f0e94a1732e3b36bc9295c4f1ce9a24fa467f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp index 09acc7c0f75bcd76412755a590605833c2bc3b5e..dba625e0ada1d4f09cb26a5860849a707a35cb97 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt index 97693a2566fc440eb8a7e71bd02251700191dfda..6079f901300fa0c95300e6b2fed5ae2ae34657af 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt @@ -1,6 +1,8 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_add_relu_add_layernorm_instance device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp ) +endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp index 47b8d23424d1dd59099441d3da8d80c1b889daf2..28a452c1a1124a326337cf4c21183035294fda1d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp index efa030ec4952acabef2a5fbd879adc07eeb2ecbf..13366238d6d9008d7ed425485670df5229f71cee 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp index f2735020e65bdbeda2c31df3c451d16903a0e320..8a4889ee83eabcae6466bdd5dd67d268c9d54287 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp index 7d4aae928b31d5c086ea717ad4be492a577b8cb1..fc3cbcf9055892bfcf63dcf6d3d3c4515026ca71 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp index e8747af4828a2492f6db8e23b575de4efb7ce90e..bfb95bce85cb8696e645ec6b5626abf84456916f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp index ed54c3a9bf58346a8829f3d7f275c6668e00ed1d..d0352339cf38fecf9c3b3d4fd8a91729ed53495f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp index da7eae637bfb16580b4ec98e479768c2649adad7..d5b298ab217b4269beca858f0d54e319bf6dbf4f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp index 34345095e0c2489f44437ea58f58b2891cc91ff1..80c8f018f416154d0b2ccd26cb9f8822233a43e4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt index 426edeed74877522b5ce19b20310193ee391a29e..aef8fe86dd43b9717664eb5ac62e9b15f47c9855 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -8,3 +9,4 @@ add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp ) +endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp index 55461dfba783de45a2991836535b3b652cc49a44..74ec9e1f8d28a77b0f78181cd6c05ea93b140e16 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp index 405e69975ca9ffec97999bce38fd73d953e139ed..eb98b3e7e871a1c5f4129032031175443ddc4de0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp index 9af31b3a129d2d1bf1ff6f32d0ee8e252cbdfa29..5f4a90125ac827ccfa406b5bf81f11313ce54258 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp index 841b7a1d47ce48b8486f0af04ff80571e4601d1a..38e3897d6a552d8a2f797b9f3b4baaa98636a53d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt index 17d27ab150f520e5405274171f9433d9e8cc5dda..772373dcb099af4ed4ea11327022f4f91e9ceca6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt @@ -1,6 +1,8 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_fastgelu_instance device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp ) +endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp index 9f7f643beb6832594bdc393777e5e08663a1f94e..803c44c7f53352109c28510c75f59e98165d6d8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,16 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b)) // outout: e[m, n] // input: a[k, m], b[k, n] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_generic_instance = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -122,6 +132,8 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( PassThrough, FastGelu>>>& instances) { + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp index c8e9f35d240da45fb902a2d02cd73d2eab5167e1..9b9ef3db24cb28e0783b5003678beeab01dd22f1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,16 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b)) // outout: e[m, n] // input: a[k, m], b[k, n] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_generic_instance = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -122,6 +132,8 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( PassThrough, FastGelu>>>& instances) { + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp index 5f804d45a560ac3b7b519cb3ab1840214e392d39..1a0b6c9d1c62d1aeb378ca140fd09959d213533a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,16 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b)) // outout: e[m, n] // input: a[m, k], b[k, n] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instance = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -122,6 +132,8 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( PassThrough, FastGelu>>>& instances) { + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index 60cb138f5654bb6f2b10ea5ba0a03191a49906ab..18b1c0e993131b62d0692769090facd14034ea8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -21,6 +21,16 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e = elementwise((a * b)) // outout: e[m, n] // input: a[m, k], b[n, k] +using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instance = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; using device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| @@ -113,6 +123,8 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( PassThrough, FastGelu>>>& instances) { + add_device_operation_instances( + instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instance{}); add_device_operation_instances( instances, device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); add_device_operation_instances( diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a4f744345542f2ec8d4b7ed2ab13d62d7a427125 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -0,0 +1,7 @@ +add_instance_library(device_gemm_multiply_add_instance + device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp + + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb38b33405b0e47a7c958fbfee024537dd64b9c3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4be9b51aa751d493c79d628a634d76fd476c577b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + // M/N/K padding + // N % 8 == 0 && K % 1 == 0 + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1> + + // clang-format on + >; + +void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..10c4453463615ab1a28fcc01f77f9dcc3d1b562e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; +using F32_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_generic_instances{}); + + add_device_operation_instances( + instances, + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..71a3e424969fbc22405e34b71b2a312af1b3d109 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; +using F32_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + // N % 8 == 0 && K % 1 == 0 + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 2, 1, 32>, 1> + // clang-format on + >; + +using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances = + std::tuple< + // clang-format off + // M/N/K padding + // N % 8 == 0 && K % 1 == 0 + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 4, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1> + // clang-format on + >; + +void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_generic_instances{}); + + add_device_operation_instances( + instances, + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp index 59e2b2da8641b9fb094750077d760e9f611d9c98..f0e7b6ab43c51abeb4e79e9e30b104bf05f82a07 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp index bb09bf8b8e82f6c2056881f37b9e65c73657f0e5..56815b9ac5fa4d89b649089c73fbfeddc985cc8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp index 0a3b566de64505adb19f75ffe5e126cfed5da438..e66d46a26995ec521df8fe15a7fd01277ebaff0d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp index 2b17e47b1cce153654d74ca7497b5c04117071de..fb1dfac69cd525ee7b0caa3d3d7b732ba5042bd8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index 6b336227465ff2602e13e94c9889a7eb0f8665cc..89dfa8f2edb3540d41566400f3e3bcc100d87927 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -1,10 +1,28 @@ -add_instance_library(device_gemm_splitk_instance - device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp -) +set(GEMM_SPLITK_INSTANCES) + +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp) +endif() + +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp) +endif() + +if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp) +endif() + +add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp index e178d3b0adeb563c36d6a8f1f3936d048df23f75..fed2cbbfb980ab27e8810b643874fc9b308b21b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp index 52be9fe709bd710f3e05f381e342deef5c7affae..44ac4c08cd97881f98f0c36a9bd5d42cd5078327 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp index c4680db831790d3e86f676bd14a07fc266deb9f6..218d6e0c2a3b6c17886514231ed4d8857ed338d5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -29,6 +29,17 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; // static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_generic_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off @@ -36,22 +47,43 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8> + //PipelineVersion::v1 + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1>, + + //PipelineVersion::v2 + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2> // clang-format on >; @@ -60,6 +92,8 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( DeviceGemmSplitK>>& instances) { + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_generic_instances{}); add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); } diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index 7fc35c4198eb49bf2625f6cfa8eb80a001fdc086..b87f7ff309c91ec70a2c15a2fac56781f9fe1690 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -28,6 +28,17 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_generic_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16, PipelineVersion::v1> + // clang-format on + >; + // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off @@ -35,19 +46,35 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8> + //PipelineVersion::v1 + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, + + //PipelineVersion::v2 + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2> // clang-format on >; @@ -56,6 +83,8 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( DeviceGemmSplitK>>& instances) { + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_generic_instances{}); add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{}); } diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..86a86c10353cff4e673e4303ec2dce5141e9f846 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c423b493eb22d115a7b74e7fd4327aea3e26907a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6168695b7fba5b8816fd8f0bdd8df943be8e826b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2, F16> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a2dd9532a1b43b3cc657f9dc16760619d758c56c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_generic_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_generic_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp index f27b2199e0c9dcb1e91566cb0aa0a05edaef5b0e..9d15ccd362774c77a9613046995d44761b8ccb49 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp index b9a1095570aaf33aa8f3f93710cbb6d3ae276430..4e9ad58742c552f437dd4b8174cb675984dfc681 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp index 44e5f597d0d714438c97a9ce7d591736f4fe7f48..330e5aff907af00920dcc4f3cb1f5f2a3b50e6fa 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp index f3a9063f7fb847b9c303ea836f97ed4cf3173ab9..0db3a15d2a4816cca200820346ccf7c9a711fa31 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..face3f456d0bbe7dc1363d2c3949ecf1ff8d2bd0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..727bd221a71bd8b4dca900e3e9600c87e4193a62 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f378fd281dd20078ab91a3e7f859001ee971485 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..70f4bdc0ecd4f7a5f6aa3c336308a95081d2a076 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d9baf3f062f5a68404e15562db19ae1e0bc17d6d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt @@ -0,0 +1,12 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +add_instance_library(device_gemm_streamk_instance + # device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp + # device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp + # device_gemm_xdl_streamk_f32_f32_f32_km_kn_mn_instance.cpp + # device_gemm_xdl_streamk_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instance.cpp + # device_gemm_xdl_streamk_f16_f16_f16_mk_nk_mn_instance.cpp + # device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp + # device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp +) +endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_streamk/device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_streamk/device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..60e5d8f5fa9c3489cbb0584787beb3ec8b59ac21 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_streamk/device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +// static constexpr auto GemmMNPadding = +// ck::tensor_operation::device::GemmSpecialization::MNPadding; +using device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_generic_instances = std::tuple< + // clang-format off + //##################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 2> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //##################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 48, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 24, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlStreamK< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_generic_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp index 05ba449246e8c6f68182eb707a3e9da0974e2078..25ea4f48cf8cd82bda19c28411d5745bc78713d7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp @@ -1,85 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using BF16 = bhalf_t; -using F32 = float; - -template -using S = ck::Sequence; - -using GNWC = ck::tensor_layout::convolution::GNWC; -using GKXC = ck::tensor_layout::convolution::GKXC; -using GNWK = ck::tensor_layout::convolution::GNWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] -using device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_bf16_f32_bf16_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - -using device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_bf16_f32_bf16_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_bf16_f32_bf16_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_bf16_f32_bf16_instances{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances< + 1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp index 7a610a747cc9cc7555bd18b17fd499d19cd7c439..8065012f1258179601dc9befd9c11a778f6e2cf6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp @@ -1,84 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F16 = ck::half_t; -using F32 = float; - -template -using S = ck::Sequence; - -using GNWC = ck::tensor_layout::convolution::GNWC; -using GKXC = ck::tensor_layout::convolution::GKXC; -using GNWK = ck::tensor_layout::convolution::GNWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] -using device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f16_default_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -using device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f16_instances = std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f16_default_instances{}); - add_device_operation_instances( - instances, device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f16_instances{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances< + 1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp index 90e074f052c6f935bcbf862250670289e0634617..c70a54c2dc5664ff863e83f6072361e69c9f86b2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp @@ -1,83 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F32 = float; - -template -using S = ck::Sequence; - -using GNWC = ck::tensor_layout::convolution::GNWC; -using GKXC = ck::tensor_layout::convolution::GKXC; -using GNWK = ck::tensor_layout::convolution::GNWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] -using device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f32_default_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - -using device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f32_instances = std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 1, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv1d_bwd_weight_xdl_c_shuffle_gnwc_gkxc_gnwk_f32_default_instances{}); - add_device_operation_instances( - instances, device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_1x1_s1_p0_f32_instances{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances< + 1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp index 74aebf1031d5b02db54fece32fa693e0c532287d..032ebc1eff914bafdc12099c2b53dcc7664e7f27 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp @@ -1,15 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -17,94 +9,6 @@ namespace tensor_operation { namespace device { namespace instance { -using BF16 = ck::bhalf_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNWC = ck::tensor_layout::convolution::GNWC; -using GKXC = ck::tensor_layout::convolution::GKXC; -using GNWK = ck::tensor_layout::convolution::GNWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// Compilation parameters for in[g, n, wi, c] * wei[g, k, x, c] = out[g, n, wo, k] -using device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances{}); + device_grouped_conv_fwd_xdl_bf16_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp index 361ea8f4ee9ff8144fd2392754fdaa53f380e40b..3a0ddb736f9dcc90b2f4021faabd13eb9ca5bb2e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp @@ -1,15 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -17,94 +9,6 @@ namespace tensor_operation { namespace device { namespace instance { -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNWC = ck::tensor_layout::convolution::GNWC; -using GKXC = ck::tensor_layout::convolution::GKXC; -using GNWK = ck::tensor_layout::convolution::GNWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// Compilation parameters for in[g, n, wi, c] * wei[g, k, x, c] = out[g, n, wo, k] -using device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances{}); + device_grouped_conv_fwd_xdl_f16_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp index 3145b716402e4e5cf8e542dc991bc7db0a655e74..1db1226b7f72f1e106561d875141df2933212154 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp @@ -1,15 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -17,93 +9,6 @@ namespace tensor_operation { namespace device { namespace instance { -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNWC = ck::tensor_layout::convolution::GNWC; -using GKXC = ck::tensor_layout::convolution::GKXC; -using GNWK = ck::tensor_layout::convolution::GNWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// Compilation parameters for in[g, n, wi, c] * wei[g, k, x, c] = out[g, n, wo, k] -using device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> - // clang-format on - >; - void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances{}); + device_grouped_conv_fwd_xdl_f32_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp index cde93f902c95d1c28f85b7a7627089fe3e97f841..6fbf631766f14b081e483328b1a501a46736e018 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp @@ -1,15 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -17,90 +9,6 @@ namespace tensor_operation { namespace device { namespace instance { -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNWC = ck::tensor_layout::convolution::GNWC; -using GKXC = ck::tensor_layout::convolution::GKXC; -using GNWK = ck::tensor_layout::convolution::GNWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// Compilation parameters for in[g, n, wi, c] * wei[g, k, x, c] = out[g, n, wo, k] -using device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances = std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 1, GNWC, GKXC, Empty_Tuple, GNWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances{}); + device_grouped_conv_fwd_xdl_int8_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 3b2968d48efd3e0061c36df1febb290b64264f73..85ec0f55aa78db969e8731213094f786a281d8d1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,8 @@ add_instance_library(device_grouped_conv2d_bwd_data_instance device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..226dca50833921c76d5c236988c34d2d01202475 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 3d604d42cc3a9fcbf8e9750ab5e3b2064f0bc161..64fbf8bbf2b341a598568d0de855083f2d416616 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,81 +1,15 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdDataDefault = - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; - -static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; - -using device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances = std::tuple< - // clang-format off - // 1. Default - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // 2. Filter1x1Stride1Pad0 - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< 2, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances( +// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( std::vector>>& instances) { + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 add_device_operation_instances( - instances, device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances{}); + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f9351d96f24ba6489ada19a4bc560df77a606a9e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5d9194798b67fc7daf81b9931f0945008383da12 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5269bb9652e689f64cf89eed8fc4200d8c18b467 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..adfa08c1f74646d4b3cda6dcd38b2dbcfbea2684 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 4009121e7fb12084e94308e3a1b1070ce9f72b11..b7b9fc92d1263f4537ebdfe976b811d6c45757cb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -2,5 +2,8 @@ add_instance_library(device_grouped_conv2d_bwd_weight_instance device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index ede21f1f4f7124156540518bf4252d3526929606..cf39c8601033c41e80f6ed6340d76d874ee5a6cd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,85 +1,16 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + namespace ck { namespace tensor_operation { namespace device { namespace instance { -using BF16 = bhalf_t; -using F32 = float; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - -using device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_bf16_f32_bf16_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_bf16_f32_bf16_instances{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 99e556618c3bfc535807ae8595a33e0f6b6d5390..d52f0b4d1c6bb99d63852b62b3c54cef29f3221b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,84 +1,16 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F16 = ck::half_t; -using F32 = float; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f16_default_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -using device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f16_instances = std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f16_default_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f16_instances{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 15871a28c3a2a85c053fc3259f9de67b5694ede6..62547a5014d527f87dc242fe41cd2e157bd93617 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,83 +1,16 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F32 = float; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f32_default_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - -using device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f32_instances = std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 2, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv2d_bwd_weight_xdl_c_shuffle_gnhwc_gkyxc_gnhwk_f32_default_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_1x1_s1_p0_f32_instances{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cb9991a86cbeecc12fe8813630596299d8f35cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e64d55c3bbb986cbf3bc314d1a4bd4e28b5a663e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5fa4c9ba3cd1482a81c5aa0ed942737cc3e60b2a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index a36e1b47caa6c35082d64523941134ecc9ffb320..6b0ad99d69d8d821e9397c0227fcecb40d0fe8e5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -1,4 +1,5 @@ add_instance_library(device_grouped_conv2d_fwd_instance + #xdl # GNHWC, GKYXC, GNHWK device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -8,6 +9,13 @@ add_instance_library(device_grouped_conv2d_fwd_instance device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp #dl + # GNHWC, GKYXC, GNHWK device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + # WMMA + device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp + device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp + # NHWGC, GKYXC, NHWGK + device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp deleted file mode 100644 index b4de825fb626067c49be31f5f6e2f01a459b4cad..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using BF16 = ck::bhalf_t; -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using NHWGC = ck::tensor_layout::convolution::NHWGC; -using GNHWC = ck::tensor_layout::convolution::GNHWC; - -using GKYXC = ck::tensor_layout::convolution::GKYXC; - -using NHWGK = ck::tensor_layout::convolution::NHWGK; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index f7e575df2b4cea0a6cc5774f5471ba9120777000..1925989838c5c9a94da651947c9e30c73c9ae342 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "device_grouped_conv2d_fwd_dl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 85300b4e44e54df0317fd305060d6a0ac6adb33b..95ef4bbe37c70bd9755429b1d2362e65aa3ac748 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "device_grouped_conv2d_fwd_dl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4a0f7b9b5bb93ba55d136ae1313cda69e999bcf1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94e57bef469e22804978b651d2c6d044cc17eda2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3904592fca3849dd8257e665e814cc98a1394eee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0cded93d676c1e3adb237c447a7a956845cc4cb9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_i8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_i8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_i8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_i8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 40593a0efba9f68e96d3760950ab9cf874b6b5cb..db14ce43060d6e837a8b5ffe49941a90329d3d42 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "device_grouped_conv2d_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { @@ -24,40 +24,36 @@ void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_bf16_instances{}); + device_grouped_conv_fwd_xdl_bf16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdDefault>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_bf16_instances{}); + device_grouped_conv_fwd_xdl_bf16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_bf16_instances{}); + device_grouped_conv_fwd_xdl_bf16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1S1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_bf16_instances{}); + device_grouped_conv_fwd_xdl_bf16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 7088028bf54ab66f3873526e098d0e16b92a30d1..debdb05b4db40df89aef26ab8b0e2f81606c4c93 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "device_grouped_conv2d_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { @@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f16_instances{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdDefault>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f16_instances{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f16_instances{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1S1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f16_instances{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 919274c5033aa1e993e4bd0c094dd9a48cc2438c..20326f5be1eb41268aab93a7db1469ec8a1957db 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "device_grouped_conv2d_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { @@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f32_instances{}); + device_grouped_conv_fwd_xdl_f32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdDefault>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f32_instances{}); + device_grouped_conv_fwd_xdl_f32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f32_instances{}); + device_grouped_conv_fwd_xdl_f32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1S1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f32_instances{}); + device_grouped_conv_fwd_xdl_f32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp deleted file mode 100644 index 2858671ee90005cd807c7f8dbb42e3823f04cb54..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp +++ /dev/null @@ -1,105 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "device_grouped_conv2d_fwd_common.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -template -using device_grouped_conv2d_fwd_xdl_f16_instances = - std::tuple< - // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -template -using device_grouped_conv2d_fwd_xdl_bf16_instances = - std::tuple< - // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -template -using device_grouped_conv2d_fwd_xdl_f32_instances = - std::tuple< - // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 25caf61df1349638f8f089acb69d2221a82be7fa..af68c3b07dc32f259b94b03dfb5f51b9084b84ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "device_grouped_conv2d_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { @@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_bf16_instances{}); + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_bf16_instances{}); + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_bf16_instances{}); + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_bf16_instances{}); + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index b997cfb6723c2699136b2eea0d90f1afce22ee6f..8b1506e0fba50c1abb0c55915f1c583fba39acf3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "device_grouped_conv2d_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { @@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f16_instances{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f16_instances{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f16_instances{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f16_instances{}); + device_grouped_conv_fwd_xdl_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index 3256a2a8268cf02bd4f8f2ec2d71bfa77c74ff21..c8bdfb8c7ce0444a988be1fa3576677ec373a279 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "device_grouped_conv2d_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { @@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f32_instances{}); + device_grouped_conv_fwd_xdl_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f32_instances{}); + device_grouped_conv_fwd_xdl_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f32_instances{}); + device_grouped_conv_fwd_xdl_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_f32_instances{}); + device_grouped_conv_fwd_xdl_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8383d622b2832aa2dba17a92944c43d495734184 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -0,0 +1,8 @@ +add_instance_library(device_grouped_conv3d_bwd_data_instance + device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8331ea1fda47c37d317e1803142ee0b044e9ed8a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1885d49c812f7fc05bd11f52c188a64672275113 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77135fcc05279cb019a5f61925b2d3f86d9bf218 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..663d41fe0ba5edfc197e46f5280bf86a6c92601e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ac0ab44ce36b72644e4747e0f426d1b35e8019f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..50d5cce73d00305bee4c3f5016d5e79970333a4d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 04cad43e75e6909e32cedc925140519708babe42..5118599b4fa6b12ffb4a8917b8ec705205b89855 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -2,4 +2,7 @@ add_instance_library(device_grouped_conv3d_bwd_weight_instance device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp index e48db4a5314528bb45f4444f93da4ba8872e86d8..c8f456db836ef3cc43198188edb30b0676bf0046 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -1,85 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using BF16 = bhalf_t; -using F32 = float; - -template -using S = ck::Sequence; - -using GNDHWC = ck::tensor_layout::convolution::GNDHWC; -using GKZYXC = ck::tensor_layout::convolution::GKZYXC; -using GNDHWK = ck::tensor_layout::convolution::GNDHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] -using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - -using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_bf16_f32_bf16_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_bf16_f32_bf16_instances{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp index 1655850ec148568734cbdbace8eefe8cb6487e0c..099123ecbc7695c2cdd3bd1c8e28113aab4d0227 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -1,85 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F16 = ck::half_t; -using F32 = float; - -template -using S = ck::Sequence; - -using GNDHWC = ck::tensor_layout::convolution::GNDHWC; -using GKZYXC = ck::tensor_layout::convolution::GKZYXC; -using GNDHWK = ck::tensor_layout::convolution::GNDHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] -using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f16_default_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f16_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f16_default_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f16_instances{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp index aba46b7ebeb90afb8881fc7ad4d8a2e1281eb146..0eda980b30bae334904b2ab93935ca9e6703a04e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -1,84 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F32 = float; - -template -using S = ck::Sequence; - -using GNDHWC = ck::tensor_layout::convolution::GNDHWC; -using GKZYXC = ck::tensor_layout::convolution::GKZYXC; -using GNDHWK = ck::tensor_layout::convolution::GNDHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdWeightDefault = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; - -static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] -using device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - -using device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances = - std::tuple< - // clang-format off - //#########################################| Num| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< 3, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvBwdWeightFilter1x1Stride1Pad0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4> - // clang-format on - >; - void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( std::vector>>& instances) { + // 1. Default add_device_operation_instances( instances, - device_grouped_conv3d_bwd_weight_xdl_c_shuffle_gndhwc_gkzyxc_gndhwk_f32_default_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_1x1_s1_p0_f32_instances{}); + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1e5c1946fbc2db558e0ddc16b5f2d5fd0655fd1b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e0b442bf2457fd33f51004ef810e4d2e690f5aec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4bb7948245e39fcd07b6d7385cfa8cc6888fa4d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 90efc09ee75f692a0664f3869cf0a6b18a7c2ce1..cd209dbf9e11e0031848fac0148f8e7e848ed480 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -4,8 +4,8 @@ add_instance_library(device_grouped_conv3d_fwd_instance device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp - device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp - device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp - device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp - device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp index b4ae8b6ce5f2ba3d8f3569449934d84fc631252c..bd8443f5ec7d1c64421be759415701b8ffea7f9c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -1,15 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -17,94 +9,6 @@ namespace tensor_operation { namespace device { namespace instance { -using BF16 = ck::bhalf_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNDHWC = ck::tensor_layout::convolution::GNDHWC; -using GKZYXC = ck::tensor_layout::convolution::GKZYXC; -using GNDHWK = ck::tensor_layout::convolution::GNDHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( std::vector>>& instances) { - add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp index 061674bd829840def99f4dcf317886844d398fea..fdd5c3169decf23e2fb7d5d660be4a8edfdf56a2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -1,15 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -17,94 +9,6 @@ namespace tensor_operation { namespace device { namespace instance { -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNDHWC = ck::tensor_layout::convolution::GNDHWC; -using GKZYXC = ck::tensor_layout::convolution::GKZYXC; -using GNDHWK = ck::tensor_layout::convolution::GNDHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances) { - add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp index ed7e5476760983e8ca5a17475585ec5eb9e834f2..b486dd80b40b94adad20e25047ad82c29e547aa9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -1,15 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -17,93 +9,6 @@ namespace tensor_operation { namespace device { namespace instance { -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNDHWC = ck::tensor_layout::convolution::GNDHWC; -using GKZYXC = ck::tensor_layout::convolution::GKZYXC; -using GNDHWK = ck::tensor_layout::convolution::GNDHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> - // clang-format on - >; - void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( std::vector>>& instances) { - add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp index bf5fa306013e485e6dd57faa011bee8217001148..de3d8664775c649fb44cab839194f1649b661b98 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp @@ -1,15 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -17,90 +9,6 @@ namespace tensor_operation { namespace device { namespace instance { -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNDHWC = ck::tensor_layout::convolution::GNDHWC; -using GKZYXC = ck::tensor_layout::convolution::GKZYXC; -using GNDHWK = ck::tensor_layout::convolution::GNDHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances = std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, GNDHWC, GKZYXC, Empty_Tuple, GNDHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( std::vector>>& instances) { - add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b3c4c1d9c96a60d8d54cc6d73084316cc7fa6bd5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d252c7513fe6fb5336b2cf9691b06b14ec72e647 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d79c1c08af6490c686675d583cd779fef84f45d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..881b6df5c549a587c767638a9330eb3e70ae1cc9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp deleted file mode 100644 index 8c384937352a012a701a21a3bb572d6b6fb7e57c..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp +++ /dev/null @@ -1,129 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using BF16 = ck::bhalf_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using NDHWGC = ck::tensor_layout::convolution::NDHWGC; -using KZYXGC = ck::tensor_layout::convolution::KZYXGC; -using NDHWGK = ck::tensor_layout::convolution::NDHWGK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp deleted file mode 100644 index 487cd22721a086bbeca32dd56639614903596905..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp +++ /dev/null @@ -1,129 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using NDHWGC = ck::tensor_layout::convolution::NDHWGC; -using KZYXGC = ck::tensor_layout::convolution::KZYXGC; -using NDHWGK = ck::tensor_layout::convolution::NDHWGK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp deleted file mode 100644 index d497cd57edfd85eda86d9641f49fa219b45e72d5..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp +++ /dev/null @@ -1,128 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using NDHWGC = ck::tensor_layout::convolution::NDHWGC; -using KZYXGC = ck::tensor_layout::convolution::KZYXGC; -using NDHWGK = ck::tensor_layout::convolution::NDHWGK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> - // clang-format on - >; - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp deleted file mode 100644 index 2e53fbbda5c5e9d4f2712b17fce2f57f7d140842..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp +++ /dev/null @@ -1,125 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using NDHWGC = ck::tensor_layout::convolution::NDHWGC; -using KZYXGC = ck::tensor_layout::convolution::KZYXGC; -using NDHWGK = ck::tensor_layout::convolution::NDHWGK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] -using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances = std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index b973b70aacc86a5e5baadd2e5b72f17704754fa4..f1553e3d5bb8b7f3992ee5bf713d4b1d93747e07 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -8,3 +9,4 @@ add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp ) +endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp index b550bb28716e14d6618dd06ed06802211c92c866..aa161e51c989b2128a9a1e1c5ad055d2678ae737 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp index a3f9c7a9e73f6f6db24a6526bf5e8195842a33ac..c454deac1a3a045f51a853d1546a70500360cc15 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index 5f5d6c9b5a95c259fc6de7e5acae2a9ad41e06c9..c829e8863d2c1156c1754ee3ae73c47a5f84a07a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 2ace1b24320b215fc5d45f21aec69ddcbc346cff..fb30e7a97317aa74d589d0987011b54dc67dc62b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp index 764ec061981ecc814a5c2a88fa01c98be17f0720..8642562fa38f25a7e487d7b4194ee8f8f16beb91 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp index a3d73440eb58db82a98c7342da0c3ed07a5890b8..90223fd9bd3606e43539b53c4880f13d4bc09f9c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -41,26 +41,47 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instanc // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, // DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1>, + + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index cffb0fce17728ae09075047bc5f1e571b7381fb2..aa6365cd98c039c36007b875d425e3cef37ba782 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp index dddfa2aa44745f4484bb9e2171e7c8c5b0bfa4e0..f4460b360be563123fdf7e02c43b49008dfc43ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ef8a440c1a98d4d596e0163cf5b42de26a43d917 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt @@ -0,0 +1,7 @@ +add_instance_library(device_grouped_gemm_bias_instance + device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp + + device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..28be904e65cd8b6690bb2b71c5ad9986d2cafa75 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using D0DataType = F32; +using DsDataType = ck::Tuple; + +using D0Layout = Row; +using DsLayout = ck::Tuple; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f5b86a0e431a154d0fb249f3b09fc06c87657dd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using D0DataType = F32; +using DsDataType = ck::Tuple; + +using D0Layout = Row; +using DsLayout = ck::Tuple; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa8441431a06b7b560cb918d13a5742a2b556446 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using D0DataType = F32; +using DsDataType = ck::Tuple; + +using D0Layout = Row; +using DsLayout = ck::Tuple; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 32, 128, 32, 8, 8, 32, 32, 1, 1, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 32, 256, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S< 1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..386a8856422b6b5a85c21b76df16609801a1f346 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using D0DataType = F32; +using DsDataType = ck::Tuple; + +using D0Layout = Row; +using DsLayout = ck::Tuple; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 32, 128, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt index 648f2146cbb84bfba719d175f3934c01e3da76fa..a45bf36399d361a85e15fdad3bed963100e564ac 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt @@ -1,6 +1,8 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_grouped_gemm_fastgelu_instance device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp ) +endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp index c2f5f00c7acb62da4a901bc965d237aeb46bcf3a..f4086b6eac4dd74903fbddb955d5bd7f41ee748a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp index 476d4ce1f87a0d2d5c109ba847273edbb31cd8e1..d68eb76144579b073f8119d431910696e6b4af9e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index 1023fa48102b7078b320be22beb19881cb678cba..2dfb8caace5144fa29934af425b005c28e9d77c1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 6b065c0f82ac87adb5acbbf21cd470392890bd6d..598a0b0e29c22effb655b09247bc84ef6c1dd268 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/library/src/tensor_operation_instance/gpu/image_to_column/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/image_to_column/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..de10369374bdf2e25ad05735590cd1a94f96e27e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/image_to_column/CMakeLists.txt @@ -0,0 +1,5 @@ +add_instance_library(device_image_to_column_instance + device_image_to_column_nhwc_1d_instance.cpp + device_image_to_column_nhwc_2d_instance.cpp + device_image_to_column_nhwc_3d_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/image_to_column/device_image_to_column_nhwc_1d_instance.cpp b/library/src/tensor_operation_instance/gpu/image_to_column/device_image_to_column_nhwc_1d_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8463623c320fadc862ca9c219d798c9df4f2a24 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/image_to_column/device_image_to_column_nhwc_1d_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/image_to_column/device_image_to_column_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_image_to_column_nhwc_1d_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_bf16_instances<1, GNWC>{}); +} + +void add_device_image_to_column_nhwc_1d_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_f16_instances<1, GNWC>{}); +} + +void add_device_image_to_column_nhwc_1d_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_f32_instances<1, GNWC>{}); +} + +void add_device_image_to_column_nhwc_1d_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_i8_instances<1, GNWC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/image_to_column/device_image_to_column_nhwc_2d_instance.cpp b/library/src/tensor_operation_instance/gpu/image_to_column/device_image_to_column_nhwc_2d_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..652c7fac2a973df3ba024c43facee646f22d851c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/image_to_column/device_image_to_column_nhwc_2d_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/image_to_column/device_image_to_column_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_image_to_column_nhwc_2d_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_bf16_instances<2, GNHWC>{}); +} + +void add_device_image_to_column_nhwc_2d_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_f16_instances<2, GNHWC>{}); +} + +void add_device_image_to_column_nhwc_2d_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_f32_instances<2, GNHWC>{}); +} + +void add_device_image_to_column_nhwc_2d_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_i8_instances<2, GNHWC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/image_to_column/device_image_to_column_nhwc_3d_instance.cpp b/library/src/tensor_operation_instance/gpu/image_to_column/device_image_to_column_nhwc_3d_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..07774504d7b16c0ef3744090613c6037ec75735f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/image_to_column/device_image_to_column_nhwc_3d_instance.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/image_to_column/device_image_to_column_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_image_to_column_nhwc_3d_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_bf16_instances<3, GNDHWC>{}); +} + +void add_device_image_to_column_nhwc_3d_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_f16_instances<3, GNDHWC>{}); +} + +void add_device_image_to_column_nhwc_3d_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_f32_instances<3, GNDHWC>{}); +} + +void add_device_image_to_column_nhwc_3d_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_image_to_column_i8_instances<3, GNDHWC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..820a63480e3ddb90432867740c94a0754019c48d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt @@ -0,0 +1,11 @@ +set(DEVICE_MAXPOOL_BWD_INSTANCES) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_bf16_instance.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f32_instance.cpp) +endif() +add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..40628d58b2baccb754ee9086f850a8168f227187 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_bf16_instance.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "max_pool_bwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_maxpool_bwd_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_maxpool_bwd_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c5cb27a7a8368d5a224dc898a63cd35fc8a920a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f16_instance.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "max_pool_bwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_maxpool_bwd_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_maxpool_bwd_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f32_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a8d5a797690ebe203df43381f9cbeb1532592dc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/max_pool_bwd/device_max_pool_bwd_f32_instance.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "max_pool_bwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_maxpool_bwd_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, device_maxpool_bwd_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/max_pool_bwd/max_pool_bwd_instance_common.hpp b/library/src/tensor_operation_instance/gpu/max_pool_bwd/max_pool_bwd_instance_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0bba106ee25b01054d398186f13024e6d7d2a22a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/max_pool_bwd/max_pool_bwd_instance_common.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I32 = int32_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +template +using device_maxpool_bwd_instances = + // clang-format off + std::tuple < + DeviceMaxPoolBwdImpl, + DeviceMaxPoolBwdImpl, + DeviceMaxPoolBwdImpl + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt index 176fb2fbee77418b877b0b76d0a2086dcce2da92..4892a9c183d0b5ad5155270a9161a5aaa9724e00 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt @@ -1,11 +1,15 @@ -add_instance_library(device_normalization_instance - device_layernorm2d_f16_instance.cpp - device_layernorm2d_f32_instance.cpp +set(DEVICE_NORMALIZATION_INSTANCES) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f16_instance.cpp device_layernorm4d_f16_instance.cpp - device_layernorm4d_f32_instance.cpp device_groupnorm_f16_instance.cpp - device_groupnorm_f32_instance.cpp device_groupnorm_swish_f16_instance.cpp - device_groupnorm_swish_f32_instance.cpp - device_groupnorm_swish_f16_f32_f32_f16_instance.cpp -) + device_groupnorm_swish_f16_f32_f32_f16_instance.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f32_instance.cpp + device_layernorm4d_f32_instance.cpp + device_groupnorm_f32_instance.cpp + device_groupnorm_swish_f32_instance.cpp) +endif() +add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp index e9c2112e16e79513d246afe635011bd8620db0b4..762da1c6ae4b98744f0040d012a6643f27e0fdce 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "normalization_instance_common.hpp" @@ -14,7 +14,11 @@ void add_device_normalization_rank_5_3_f16_instances( std::vector>>& instances) { + add_device_operation_instances(instances, + device_normalization_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp index 79dde38fc90d370c8a483a0187a1e95f3abd9d01..44b553bd1604f41e9dce318cdb0e044ca722f461 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "normalization_instance_common.hpp" @@ -14,7 +14,11 @@ void add_device_normalization_rank_5_3_f32_instances( std::vector>>& instances) { + add_device_operation_instances(instances, + device_normalization_f32_generic_instance{}); add_device_operation_instances(instances, device_normalization_f32_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp index 9f6bf128faebd85190b93c2d8e6a08ab1c9f1baa..aa662b7dfe16ccfab5b0414328cdbab4459ed5b0 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_f32_f32_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "normalization_instance_common.hpp" @@ -14,8 +14,12 @@ void add_device_normalization_rank_5_3_swish_f16_f32_f32_f16_instances( std::vector>>& instances) { + add_device_operation_instances( + instances, device_normalization_f16_f32_f32_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_f32_f32_f16_instances{}); + add_device_operation_instances( + instances, device_normalization_splitk_f16_f32_f32_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp index 6241e0338561fca8b6e286cec098c8edefcf6424..bc5cd801aee3a606337487aa224c15082e0d7442 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "normalization_instance_common.hpp" @@ -14,7 +14,11 @@ void add_device_normalization_rank_5_3_swish_f16_instances( std::vector>>& instances) { + add_device_operation_instances(instances, + device_normalization_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp index b64328d5d07518e9ff4f259d35edb11811239e63..4b2ab3357002f1b48085540a2d95758f28b6691c 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_groupnorm_swish_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "normalization_instance_common.hpp" @@ -14,7 +14,11 @@ void add_device_normalization_rank_5_3_swish_f32_instances( std::vector>>& instances) { + add_device_operation_instances(instances, + device_normalization_f32_generic_instance{}); add_device_operation_instances(instances, device_normalization_f32_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp index d6a2f6f2c1c9c86d37179bf6c7a00cdcf6e359bc..0d235f1fa7dd4b44d6f24b7a7f97e458fab15705 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "normalization_instance_common.hpp" @@ -14,7 +14,11 @@ void add_device_normalization_rank_2_1_f16_instances( std::vector>>& instances) { + add_device_operation_instances(instances, + device_normalization_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp index 73097828e3b19017453f197d8eb4ae152cb10fa9..00039531e18489cdc217c0e19ee5fc3b3c59bcdf 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm2d_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "normalization_instance_common.hpp" @@ -14,7 +14,11 @@ void add_device_normalization_rank_2_1_f32_instances( std::vector>>& instances) { + add_device_operation_instances(instances, + device_normalization_f32_generic_instance{}); add_device_operation_instances(instances, device_normalization_f32_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp index 507a683ee7a8a435604dba0a5c56969ba84d9a1b..6bc395006221048701a281c378a6f064f4556ff9 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "normalization_instance_common.hpp" @@ -14,7 +14,11 @@ void add_device_normalization_rank_4_3_f16_instances( std::vector>>& instances) { + add_device_operation_instances(instances, + device_normalization_f16_generic_instance{}); add_device_operation_instances(instances, device_normalization_f16_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp index ca1aa0c25cecbd4ff67df53b5e53961877faa55b..b387dc2f3ffdb6dc8f156fff74b9643af0d56343 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_layernorm4d_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "normalization_instance_common.hpp" @@ -14,7 +14,11 @@ void add_device_normalization_rank_4_3_f32_instances( std::vector>>& instances) { + add_device_operation_instances(instances, + device_normalization_f32_generic_instance{}); add_device_operation_instances(instances, device_normalization_f32_instances{}); + add_device_operation_instances(instances, + device_normalization_splitk_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp index 9dea41e89d3ca119450905d222e95c77f95f0afe..7aa3da8eedf863634a78113980092685f664fa3f 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp +++ b/library/src/tensor_operation_instance/gpu/normalization/normalization_instance_common.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" #include "ck/utility/data_type.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -43,6 +44,39 @@ using device_normalization_f16_instances = // clang-format on >; +template +using device_normalization_splitk_f16_instances = + // clang-format off + std::tuple < + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl + // clang-format on + >; + +template +using device_normalization_f16_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationImpl + // clang-format on + >; + template using device_normalization_f32_instances = std::tuple< // clang-format off @@ -69,6 +103,39 @@ using device_normalization_f32_instances = std::tuple< // clang-format on >; +template +using device_normalization_splitk_f32_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl + // clang-format on + >; + +template +using device_normalization_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationImpl + // clang-format on + >; + template using device_normalization_f16_f32_f32_f16_instances = std::tuple< // clang-format off @@ -95,6 +162,39 @@ using device_normalization_f16_f32_f32_f16_instances = std::tuple< // clang-format on >; +template +using device_normalization_splitk_f16_f32_f32_f16_instances = std::tuple< + // clang-format off + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, // irregular size + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl, + DeviceNormalizationSplitKImpl + // clang-format on + >; + +template +using device_normalization_f16_f32_f32_f16_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationImpl + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..63bbe75465df3d41c7a6300d26abe8e6ce988723 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt @@ -0,0 +1,14 @@ +set(DEVICE_POOL3D_FWD_INSTANCES) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp + device_max_pool3d_fwd_ndhwc_f16_instance.cpp) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp + device_max_pool3d_fwd_ndhwc_bf16_instance.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp + device_max_pool3d_fwd_ndhwc_f32_instance.cpp) +endif() +add_instance_library(device_pool3d_fwd_instance ${DEVICE_POOL3D_FWD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp similarity index 51% rename from library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp index 38338ff9976cc28730bc36eee225529f31da949a..f10251699230b2dbb64d91ad966fc317b5fc0e5d 100644 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "pool_fwd_instance_common.hpp" @@ -10,11 +10,13 @@ namespace instance { static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; -void add_device_pool2d_fwd_nhwc_f16_instances( - std::vector>>& instances) +void add_device_pool3d_fwd_ndhwc_bf16_instances( + std::vector< + std::unique_ptr>>& + instances) { add_device_operation_instances( - instances, device_pool2d_fwd_nhwc_instances{}); + instances, device_pool3d_fwd_ndhwc_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp similarity index 72% rename from library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp index 6fcb519a9b822fd0cdcc65e14ec4ecbb9d3f269a..4ebd50bae6bfd0f11cb56f5c03e196eceff4bd5a 100644 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "pool_fwd_instance_common.hpp" @@ -11,7 +11,9 @@ namespace instance { static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; void add_device_pool3d_fwd_ndhwc_f16_instances( - std::vector>>& instances) + std::vector< + std::unique_ptr>>& + instances) { add_device_operation_instances( instances, device_pool3d_fwd_ndhwc_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp similarity index 72% rename from library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp index 67ffd470858175089475e459e869bb7cc6de9b30..dcb19110b2fb8b46baed73e7081c0227b0a39ee6 100644 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_avg_pool3d_fwd_ndhwc_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "pool_fwd_instance_common.hpp" @@ -11,7 +11,9 @@ namespace instance { static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; void add_device_pool3d_fwd_ndhwc_f32_instances( - std::vector>>& instances) + std::vector< + std::unique_ptr>>& + instances) { add_device_operation_instances( instances, device_pool3d_fwd_ndhwc_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_bf16_instance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5dc504e1784df372358c7675ecd99c7b96ee74f2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_bf16_instance.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "pool_fwd_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; + +void add_device_pool3d_fwd_ndhwc_bf16_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_pool3d_fwd_ndhwc_instances{}); +} + +void add_device_pool3d_fwd_ndhwc_index_bf16_instances( + std::vector< + std::unique_ptr>>& + instances) +{ + add_device_operation_instances( + instances, device_pool3d_fwd_ndhwc_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp similarity index 68% rename from library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp index f3367b946098874d8533785c56c77bcfdcfd900b..46b16bd0053e4789d1c4e7a3ca59ec670188e37d 100644 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "pool_fwd_instance_common.hpp" @@ -11,14 +11,18 @@ namespace instance { static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; void add_device_pool3d_fwd_ndhwc_f16_instances( - std::vector>>& instances) + std::vector< + std::unique_ptr>>& + instances) { add_device_operation_instances( instances, device_pool3d_fwd_ndhwc_instances{}); } void add_device_pool3d_fwd_ndhwc_index_f16_instances( - std::vector>>& instances) + std::vector< + std::unique_ptr>>& + instances) { add_device_operation_instances( instances, device_pool3d_fwd_ndhwc_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp similarity index 68% rename from library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp index 8477a884da7fda20301ab88f3bd3b4414f8dffeb..b4b0e74d287afdb88e2fbd30a34b8c4a2574357c 100644 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/device_max_pool3d_fwd_ndhwc_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "pool_fwd_instance_common.hpp" @@ -11,14 +11,18 @@ namespace instance { static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; void add_device_pool3d_fwd_ndhwc_f32_instances( - std::vector>>& instances) + std::vector< + std::unique_ptr>>& + instances) { add_device_operation_instances( instances, device_pool3d_fwd_ndhwc_instances{}); } void add_device_pool3d_fwd_ndhwc_index_f32_instances( - std::vector>>& instances) + std::vector< + std::unique_ptr>>& + instances) { add_device_operation_instances( instances, device_pool3d_fwd_ndhwc_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/pool3d_fwd/pool_fwd_instance_common.hpp b/library/src/tensor_operation_instance/gpu/pool3d_fwd/pool_fwd_instance_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e8e781329531799239c39c13d72fa1c3c044884c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/pool_fwd_instance_common.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I32 = int32_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using NDHWC = ck::tensor_layout::convolution::NDHWC; + +template +using device_pool3d_fwd_ndhwc_instances = + // clang-format off + std::tuple < + DevicePool3dFwd_NDHWC_NDHWC, + DevicePool3dFwd_NDHWC_NDHWC, + DevicePool3dFwd_NDHWC_NDHWC + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/pool_fwd/CMakeLists.txt deleted file mode 100644 index 0d0f896c8d932384fbe8e3e84a7f1d0e6b2c388d..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -add_instance_library(device_pool_fwd_instance - device_avg_pool2d_fwd_nhwc_f16_instance.cpp - device_avg_pool2d_fwd_nhwc_f32_instance.cpp - device_avg_pool3d_fwd_ndhwc_f16_instance.cpp - device_avg_pool3d_fwd_ndhwc_f32_instance.cpp - device_max_pool2d_fwd_nhwc_f16_instance.cpp - device_max_pool2d_fwd_nhwc_f32_instance.cpp - device_max_pool3d_fwd_ndhwc_f16_instance.cpp - device_max_pool3d_fwd_ndhwc_f32_instance.cpp -) diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f32_instance.cpp deleted file mode 100644 index 0f4a35deefcb40a1ef4cdfd90bb5e3f95cc412f6..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/device_avg_pool2d_fwd_nhwc_f32_instance.cpp +++ /dev/null @@ -1,23 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "pool_fwd_instance_common.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; - -void add_device_pool2d_fwd_nhwc_f32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_pool2d_fwd_nhwc_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp deleted file mode 100644 index a41cd0094bd90b5fe3ac9982b1cdd34481eec740..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f16_instance.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "pool_fwd_instance_common.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; - -void add_device_pool2d_fwd_nhwc_f16_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_pool2d_fwd_nhwc_instances{}); -} - -void add_device_pool2d_fwd_nhwc_index_f16_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_pool2d_fwd_nhwc_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp deleted file mode 100644 index fa70569bab495090dd6290657727a6bd30162b81..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/device_max_pool2d_fwd_nhwc_f32_instance.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "pool_fwd_instance_common.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; - -void add_device_pool2d_fwd_nhwc_f32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_pool2d_fwd_nhwc_instances{}); -} - -void add_device_pool2d_fwd_nhwc_index_f32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_pool2d_fwd_nhwc_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp b/library/src/tensor_operation_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp deleted file mode 100644 index cd508b55b0cfdc639ca7db204579ce2c3e7321df..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/pool_fwd/pool_fwd_instance_common.hpp +++ /dev/null @@ -1,55 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" -#include "ck/utility/data_type.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using I32 = int32_t; -using F16 = ck::half_t; -using F32 = float; - -template -using device_pool2d_fwd_nhwc_instances = - // clang-format off - std::tuple < - DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C, - DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C, - DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C - // clang-format on - >; - -template -using device_pool3d_fwd_ndhwc_instances = - // clang-format off - std::tuple < - DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C, - DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C, - DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index db650dbfbdb41ad6a300c787a37fe6f8ea2719cc..00eb6ff1c13f66c93a681e9bcfe5aa8aedc28ae9 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -1,33 +1,26 @@ -set(CONV2D_PERLAYER_QUANT_SRC - conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp - conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp -) - -set(CONV2D_PERCHANNEL_QUANT_SRC - conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp - conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp -) - -set(CONV2D_BIAS_PERLAYER_QUANT_SRC - conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp - conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp -) - -set(CONV2D_BIAS_PERCHANNEL_QUANT_SRC - conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp - conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp -) +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) +set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp) +set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp) +set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp) +set(CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp) set(GEMM_QUANT_SRC - gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp - gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp - gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp - gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp ) +if(DL_KERNELS) + list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp) + list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp) + list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp) + list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp) + list(APPEND GEMM_QUANT_SRC + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp) +endif() add_instance_library(device_quantization_instance ${CONV2D_PERLAYER_QUANT_SRC} @@ -36,3 +29,4 @@ add_instance_library(device_quantization_instance ${CONV2D_BIAS_PERCHANNEL_QUANT_SRC} ${GEMM_QUANT_SRC} ) +endif() \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp index 672cdba65d31799eee77622c12da7e405446d6ad..711314985a7cd2ba708138df597e37d1e874709f 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp index d4b5484d8b05a9a9c7546f9eb0eb1eeea9414442..39c4f82fefd4beb2a0dc3629f0757f7f01037495 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_dl_int8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp index 7db4b8d8695d655dddad18fe7b63d89c5d6b064b..92e73eb2ee45895c4277eb46f2a5ec94116fc1e2 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_dl_int8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp index 7eefbe0381c19cba139399cff386a4841768e64c..2ec37c8413e7ca926996ed8165c7e23730e77769 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "conv2d_quantization_common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp index c8f5f7042c92c5aaea7ab91cb27c7d2b243a2698..1d8b58fd18a49b777158c266459925a8ef67079b 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_dl_int8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp index d7f7384ff78262cf21680ddc1c03238194540499..62826d0d2b3bb541cb70ed1fc7f70f6c55a431c1 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_dl_int8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp index 658aa8370a43aa27d9ca60bcdf303ede664fbb51..99877c32dba05ea6386bb2f94d888e7fe6d95279 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_xdl_int8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp index 7102e9b259cabb4b45783d915ac43a4f5a34f92b..50ccc69f458fc9a46bbe5e08510e3da1174a90cc 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_xdl_int8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp index 90f8791aa94ca22a0be19f0792cc7c6879a6ce5f..caced6c950a7872714b76aadee100e3a524a03d8 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp index 9d6937708550a2d7b5cf27a1a4c822366b01d6a8..526fe73463cba7ac977b597c2279488aaa9afe6b 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_xdl_int8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp index d6f87335b2d3c21996be54adbffe39b36e0a3e97..d1d5a66bb949b391909426aa13ac4bc4d9e91be6 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_conv2d_xdl_int8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp index 9cad8d4c8e8c14b2fd15a802a102c4de53f5afef..0ec498a441d26c37cb5e449213f9b1048c36c9c5 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_quantization_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp index ffe1efb80b5721230bcb059412a18d1c4464f3af..4c02f8179caf2a7bf5ee6a46111f0b9d610e7667 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp index 7f24e5677c482ac881ddebd8d0aeab9d1aca0b39..2c635e96d228fb4283fe518dbf010c66fc9a4272 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp index 06e66cfe037696ff9f3198bf5308579a3e18fea8..c5a52f2cf25137fbbfac20d6451db1c67c1ea660 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 16635d1e93de1b80933dddbc38017b2716b56132..fe4a0b4d68b4ed659aef0e6ab23d4a443cdfde72 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp index dfb8dc29b49c8501671e3b17c00e7829c9941deb..d1b18ac4929bb114ab870a735b6b182e7a4e1343 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_quantization_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp index c153cdf9ea9e2f02bcc28b1c50f707cc683e0211..e3163c32ed1cb9704e00f5d46e82cfacb187aa60 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp index f6cd32026f92ead80af192fb90c2ce794c26c398..f6a0163aba346bc731f3241f0426306e5eeb43f3 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp index 45fbacc334d0974644641e5d43edde90f7aa9c26..0cfffbed70fdc95b6161c06a634bf379cc7c2651 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 257633fe19a1360d44b8fa9d768918e106210909..289898257b76ab48c44e06259dd067096f8489cd 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp index 213f42b91b6365bdc5e346c060ca33a8c2ebb71a..e7c2500fef2fac9e90639301c60592dcb8978ccd 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.cpp index cf46059a0d2cb9f397209124f15f808bc51fd165..4771b4aa5de2ccfb5b7f9a6205f64594d37bdabd 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.cpp index 0043b198440c2a1da45b10cd81197b493d51b077..f684dabc1fcd60bd6522a9865f18ac39dcdc3d0e 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.cpp index 6f702ddf1ffee3a684dd66c124a3a2d7ef1f8654..86e2ca7c47cb8f2a318c745b19e0fcec058c7318 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.cpp index d1f70dc99e87606c0a388b57a181aa12c6ff26a5..97d09ff834fc46a65b8f25182238349f70bce818 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.cpp index a957981a55233f7f2c9e5039bfb4705cf2322850..87187f493acf8b8247e8877b369edfc106a6aaa1 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.cpp index 550a9cd76c6db474de2b78f323a5b8feed936bd7..d3f5cd8656940f1adad49d0fd38725a5e04167ea 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.cpp index 58cb6ee34853d2670ee89b2fde99e42cb3160819..fb4a50d5e3539c3c2db9e799ea2c824a59c90c7c 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.cpp index 1ac5e79bc16cfa587ae05072a72c8b43de08ba2a..f74d22f7acadb2ab97a665adfb6eb079ed34cb64 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.cpp index b1e1a06800ec06e1fa0ab388ae63ca20d29bbe0a..0982455add502da7eb5f9365f4b128673846db82 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.cpp index 1a15b32d23fbe533322249132a8b759745d6362c..c6f1c3a6338115a42efc02311228af8b488c4a8d 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.cpp index 119f384b4ebcc8b446d0ace7c8bd65a86528e13f..db9f52a13fe436252d57644b9d76db41170b042b 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.cpp index 3f1bd86b8b8f688b3e341c24535d45d99a0bc941..1410f51bea100f3aa56b369d8442d3c81fc79ebe 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.cpp index b507f0d1f90621800e73a5d23fbbfa88f1443684..707a655703503001118da5d2e28e53f24ecc135c 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp index 04d0ea2e8b86c78d0348e0e0878a6a21ccbf1004..17f45c3327c63296291aef9132461f5bb56074b2 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.cpp index 3de561f2b615cc1ac18962eabacf583afb8c11aa..d3753d01c4e65dc2344ff4d8199bd81d674f5c96 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.cpp index 3f45b03134a85322874019f7e7dcdd102a4d5013..83fad7826a070ef91abcc33c68b1230b410fb8f5 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.cpp index 76851d9b726c4f75ec68e2206d7fc0d92c91e38e..c627f686363a9a12e251797b45060d56d2947d33 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.cpp index 9cef019320b8211a9e1228e322517eb0f28e65fd..09686d98ffcacb07f5f9244c196daeec993780dc 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.cpp index ce73ec47e3c2078e9c7629107339e1db17c4934f..44519032f31cba47f9606aee66dce5b989685931 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.cpp index ed6091f92488263f3e3bbaa1804102f674cba288..9e14d80e21ae1956595eca87ab6ab859c42c52f6 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.cpp index 4c8375de169b3f94eb30696162248fda78b375bb..ab217f1160481e0155cc3f0e6e62dc21f1f70fe2 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.cpp index 0fa93ab68879f310bac2698a241a23017e4b4064..3f611da7e0da0d7f61b3af38f18ac29b8dc5b2a0 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.cpp index 821eec1751a8b7c3e11884e53749da844d2a3fec..8695db4004081ee0197eebd623cf539286236d3e 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.cpp index 0305b4945f14f91398492aa53cf0795043a19c6c..8d15fda7869d85daa457ff5cfb69bfcbb719cf9b 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.cpp index 1bda0bcc71ad37daab5a1e36a188fffd5c10fb7a..c161858b1f161694485d5531485995490949112d 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.cpp index 7f8018a04eb1b7cb9fed18b18b61c6dc1dd586f3..e6b87fdead361de49fbb28e5a709b2d5fea812b2 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.cpp index 887a89cc2babe127e32a8f6c7b6b42b2645ed614..c2e4acc8d9338824df7811d2434474ceddf5ec6d 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.cpp index 0cc810363d4ca4191add9898f645327d27520203..e5a5e07e176b68ed8eca614a52a432efff1ca9d7 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.cpp index 4c825a9f1baa5c84ab746777c0f4b7839b9b6c86..22c62033454a9525f87fc7b19f8ab0f6186a3cda 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.cpp index bf26913fd31c54d2d3d5e4459ed30179be3f17e0..1890fb78cd7c7fd387f6f9d3290009b54ca3ba96 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.cpp index 629299c7b112268ab166bacc0350aa53646393ad..dc5caf297bc971950c2424496ec1c2a2fe2c2a91 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.cpp index 9a086344985249eec8ca08a7db6a3d6a6859a033..02bb0c3b1d9d5e8db84a0338718c66a4cc7e66c9 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.cpp index 6dc925bd6f3f790de62b56598f6d813ee4920f7b..0b0a40303805ccfb4d9ea93e4b1295cc7b2edb6e 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.cpp index 470d68d37246c5f5b549068df36c1b29f5b42e74..fcf95aab7a37aa9c79ef6e22717afebb24052406 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.cpp index 39303ab58002eb1a538a51a621822a1f715c5a9a..5151d32a58e9c321985faf6f8cb7508e5cc4ad5f 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.cpp index a5481784ed675cc36cf6ab8d0ec64d0961934030..5463cc2b4c2e5762f08db5593bfc51e204a41713 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.cpp index aa6e6d3cc976981103d8cb226fe100c5161bee59..a15e99fe0f57a2aab28811660dd910645cef482a 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.cpp index b1ea551eaa775b867114776e171687ff1ee8ea27..ee7a4dd4cef85adf56094750c09558b3dea1ec5a 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.cpp index 2ba83132d8f42cfb4695c93a602112cad34acecc..58f7295a835ac90d6bc13125f2dc290544497bc0 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.cpp index b9018e7c6af6e93f00825b0e72b0d3338a788960..43f2370a2c37cbe181ac99b81c51313f1b89d3b3 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.cpp index 1abc6de5526d5bd3bb89d95c7e58fb54bc76ef20..eaf84a6de1b7ba55446a2994ff5cdf9eb939e625 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.cpp index f3a017aeb422e08573c9d1fbcb58f763baa14b09..76d63c9a96372e9c3cd97de281d4e841d92ee71a 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.cpp index 329617bb432d742eab9b3021ebdfc3f5c0812d3d..eb1d389c398fe609b60d16c9202ef2f117cd43ec 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.cpp index 1e4d43debd2992fee6add61f9ea5f339f7cc4db9..61225df05dd1d0e444c76fdf3c0a318b80d17380 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.cpp index f9f79675f59711b6dee3aac4f8a3cc1d0a4eaad8..aef231e17ac12f2b6380213ce810831da52c7ad1 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.cpp index d3e7268c5a53106c85ec3dc19400191874def55d..6d2939e3e0e3ab7e196c390b4758bd0ab0345ba2 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.cpp index a41a12386d021c952d1c014229b1de233290539c..ff0ddd773b75603ca245bd5ca3821d5443db00d9 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.cpp index 6da1acc4e54de8fc86fd60ce449bc759f8cfce98..1367022065226226457941e729b71c0d1b66b387 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.cpp index f14b8a40312d265e231ded1d4dd4b5bc47fcc8ea..f0f93019189aa751fbc39dcafed73cc42fd0f85c 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.cpp index 5a9f08167c9a027058edc7a513781e26804513c7..db9f1566008d81c0c5d5e09f2b73c13d7c77fdad 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.cpp index d3aff06753a1c689fd6362fafdb4b4ca9a3d940f..e5644faf291740876ee4daeaf66fc0e43e20b0ce 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.cpp index 55f7537d8334586c4aa6ff52dc8e20cb63650efd..350259f5f0c1ebc258bd202f182e54d600b619cd 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.cpp index 70f91168d8bc3025758f64186a1a379b89ce3b8e..89cae5527342997a0db882b1c438a8dc117e972f 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.cpp index 47f5e67fe080d8c6ef3bb29b25ac4ed74a80d4ab..b0f82ab45350b5bb71e15f609dc0e01264df8728 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.cpp index eae489ff04d08ccb453ce1088f7140dafaecbd89..fedd0f890bd8987a1bf1b9b7f29d99c5293676f2 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.cpp index 9fb267a201724bb925884efe625160817765257f..64ec1ce1574a3608be4db2a88b02d26c1f1c0835 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.cpp index fecb2691f8679cf97f489ffb09619c8bd092c16f..156f182f55a1e515e337b7a3cd20f124f618fcec 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.cpp index 232d2b858bcf651001fb4195b981d1e0c0f6c9ec..22b278790d970dbb4cb68741f653606cbd922abb 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.cpp index 07d45c4ca994aae5384732901783cb69c16c2c4c..f3070d297587f8510ee21b553d7299a9cd30d745 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.cpp index 596a062f3a227b71dcb1ed15122f31b0017e584b..e17476ba75111134f67457df836415323c93a2ce 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.cpp index 7270cefe8b4a8ae83c374893a5d3bf57070afd86..2387725b9f74b543e1db6ca3de372838a9badaea 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.cpp index d0f4ef3dff503c92a02713b4832776f6a182ea2a..8be94118c4864ec38087204d87f1fbb83b85ec43 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.cpp index 9c6bce92f45de792a722a32630d023ed6ea177a1..9ff01615c9c83a51ac995cd5a7a6eb840c6fafa5 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.cpp index 5faf8d8283147fc9e56f74f375761a8f6b5022df..8a99c7c95879554615a61e46a998ef213c9a5cdb 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.cpp index 8f3c72451a3e05bf75398c6586fa019b2572f698..78909c8dbf5f35e0241a1802c4e589085b4e7ead 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.cpp index 8d7794f42e3b4d52d510bd14bef0186e3c357376..7f094c28b756e020a47b1a78454f839ac9fb540e 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.cpp index 4a32543a136d435d1a1863aebb238904d0e34c1c..4dfb477c1e74f4a3863df56fddbe815ea5d8a321 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.cpp index 26d571c84e5295fdcd6e5dbfb7827889c75efc31..e4928336069cc19a6d161d035407ed8a63deefaf 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.cpp index ae56a2a91999c686cb56d358311bcd8bc320d376..de8dce357bddc3470d8d17142b11f0cac7c6a507 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.cpp index aae3233c9f0be15daa3c9004881f1a7afc85c89c..0225e75cb09a770c94597dd0caffee13f05b1b27 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.cpp index 94d5d3fa2f45f31647779ea451ba1ac4b833c9c8..6d04d2128f651c45b9abd4c2265c89c60495d45d 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.cpp index dad190a63484a50121c71096547264659115219c..046643abe1a6cb951204d3e624f89f678be71176 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.cpp index b7ca6998f5e47d732b5998c52e85d5b9dcef9fd4..a61302c001f7ecbcad117f5f8df38cc9ec9c3c14 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.cpp index 22c40187ea9be8f1163693908c60a39a50f13d45..7619e12a8e6bfe37e09abccb2f2672b9458cc934 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.cpp +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/reduction_enums.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp" diff --git a/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt index fc13261a6a74c49594240969cdae03ee4f5e5b7c..ba0197477f507f163b9d88f25cf81375649b5fb4 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt @@ -1,26 +1,20 @@ -add_instance_library(device_softmax_instance - device_softmax_i8_i8_instance.cpp - device_softmax_i8_i8_instance_rank3_reduce1.cpp - device_softmax_i8_i8_instance_rank3_reduce2.cpp - device_softmax_i8_i8_instance_rank3_reduce3.cpp - device_softmax_i8_i8_instance_rank4_reduce1.cpp - device_softmax_i8_i8_instance_rank4_reduce2.cpp - device_softmax_i8_i8_instance_rank4_reduce3.cpp - device_softmax_i8_i8_instance_rank4_reduce4.cpp - device_softmax_f16_f16_instance.cpp - device_softmax_f16_f16_instance_rank3_reduce1.cpp +set(DEVICE_SOFTMAX_INSTANCES) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f16_f16_instance_rank3_reduce1.cpp device_softmax_f16_f16_instance_rank3_reduce2.cpp device_softmax_f16_f16_instance_rank3_reduce3.cpp device_softmax_f16_f16_instance_rank4_reduce1.cpp device_softmax_f16_f16_instance_rank4_reduce2.cpp device_softmax_f16_f16_instance_rank4_reduce3.cpp - device_softmax_f16_f16_instance_rank4_reduce4.cpp - device_softmax_f32_f32_instance.cpp - device_softmax_f32_f32_instance_rank3_reduce1.cpp + device_softmax_f16_f16_instance_rank4_reduce4.cpp) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f32_f32_instance_rank3_reduce1.cpp device_softmax_f32_f32_instance_rank3_reduce2.cpp device_softmax_f32_f32_instance_rank3_reduce3.cpp device_softmax_f32_f32_instance_rank4_reduce1.cpp device_softmax_f32_f32_instance_rank4_reduce2.cpp device_softmax_f32_f32_instance_rank4_reduce3.cpp - device_softmax_f32_f32_instance_rank4_reduce4.cpp -) + device_softmax_f32_f32_instance_rank4_reduce4.cpp) +endif() +add_instance_library(device_softmax_instance ${DEVICE_SOFTMAX_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.cpp deleted file mode 100644 index 14d2764529c826587b13661e7a7cab11e8d6ea99..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.hpp" - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_f16_f16_rank3_instances( - std::vector>& instances) -{ - add_device_softmax_f16_f16_rank3_reduce1_instances(instances); - add_device_softmax_f16_f16_rank3_reduce2_instances(instances); - add_device_softmax_f16_f16_rank3_reduce3_instances(instances); -} - -void add_device_softmax_f16_f16_rank4_instances( - std::vector>& instances) -{ - add_device_softmax_f16_f16_rank4_reduce1_instances(instances); - add_device_softmax_f16_f16_rank4_reduce2_instances(instances); - add_device_softmax_f16_f16_rank4_reduce3_instances(instances); - add_device_softmax_f16_f16_rank4_reduce4_instances(instances); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp index fa334b997c27af20f05c42335bde28f5f32f819a..36867d993f9163bd7bdb8e81eb0ed51920ad2458 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce1.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f16_f16_rank3_reduce1_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<3, 1>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp index 1c9d37d8483b7f249a0b1c05551c70acdb46db03..373f33ad59716bcc2ddeb628a92dd0d46f7ed0aa 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f16_f16_rank3_reduce2_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<3, 2>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp index 5fbdab5055edb0f124bbc6541fcc705a12b71d63..d26b92b4f4998919d046401fb5f813594def112c 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank3_reduce3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f16_f16_rank3_reduce3_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<3, 3>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp index 7dd8640b187a2b793d5ad6931ab329ae024ca432..bbb735b6fe564e586cc269eb030d2dc38d4f7cd1 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce1.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f16_f16_rank4_reduce1_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<4, 1>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp index b32fe6838f83440e98f524c59af9a383760dfe7d..92dbe6776039a9d8e3fd2d358abac54d33fe4537 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f16_f16_rank4_reduce2_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<4, 2>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp index c05048ec567bd6471fca14c41d062eaf8c0c86a5..354cda85d757cb04d2b803ffec32e3a6c36f1856 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f16_f16_rank4_reduce3_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<4, 3>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp index 6a235708bd426c7270f16246fa058eae45cdd056..edb5e42c103badf9fd65f482577db9ccc2ee08ab 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f16_f16_instance_rank4_reduce4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f16_f16_rank4_reduce4_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f16_f16_instances{}); + add_device_operation_instances(instances, device_softmax_f16_f16_generic_instance<4, 4>{}); + add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 4>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.cpp deleted file mode 100644 index e5bec5e2639d234b3944f679e99e988f20b2c383..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.hpp" - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_f32_f32_rank3_instances( - std::vector>& instances) -{ - add_device_softmax_f32_f32_rank3_reduce1_instances(instances); - add_device_softmax_f32_f32_rank3_reduce2_instances(instances); - add_device_softmax_f32_f32_rank3_reduce3_instances(instances); -} - -void add_device_softmax_f32_f32_rank4_instances( - std::vector>& instances) -{ - add_device_softmax_f32_f32_rank4_reduce1_instances(instances); - add_device_softmax_f32_f32_rank4_reduce2_instances(instances); - add_device_softmax_f32_f32_rank4_reduce3_instances(instances); - add_device_softmax_f32_f32_rank4_reduce4_instances(instances); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp index 57d3f184a6635f4febe632054c9e39798ef0cc82..566be8fc22c4a5b8a5a6bb17ddebf012bf39d1d4 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce1.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f32_f32_rank3_reduce1_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<3, 1>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp index fae3a4dd6662d8920b1b88c8489e9a6f00f28ed0..f9c76e3116cd412c4efaca27595c5ebedf205943 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f32_f32_rank3_reduce2_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<3, 2>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp index b6fb70e8e2a6aa097f2a76865f690812dffaffdd..541e0d71a939c5964e560d76771986bdbc1cfc03 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank3_reduce3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 3; - void add_device_softmax_f32_f32_rank3_reduce3_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<3, 3>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp index 33c7b6f35f351f755a50f48f92d391fd47f40163..95a38df2834b1794384bfdab57087c051a8fc482 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce1.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f32_f32_rank4_reduce1_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<4, 1>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 1>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp index c22aa574b1f984d5a4f8c3276c9fb11b74002680..a29b88891d45b7f967629d956b69bc25c371ef88 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f32_f32_rank4_reduce2_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<4, 2>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp index 55f3d2bd207a0a160f00415592f0106f41be5cee..0da46ea1b47489e9eb9fff991c7c5730e4b5338c 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f32_f32_rank4_reduce3_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<4, 3>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp index fb0bcf5ee8a2dd6fd5ed4f77c0d8f4be5c5fcacd..fa217dc3f5b341112645c95df44826f19d2105c4 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp +++ b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_f32_f32_instance_rank4_reduce4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include @@ -13,12 +13,11 @@ namespace tensor_operation { namespace device { namespace instance { -static constexpr index_t RANK = 4; - void add_device_softmax_f32_f32_rank4_reduce4_instances( - std::vector>& instances) + std::vector>& instances) { - add_device_operation_instances(instances, device_softmax_f32_f32_instances{}); + add_device_operation_instances(instances, device_softmax_f32_f32_generic_instance<4, 4>{}); + add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 4>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.cpp deleted file mode 100644 index 608cfcf8380be29e24d5e4b3aadd7573285a5224..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp" - -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_softmax_i8_i8_rank3_instances( - std::vector>& instances) -{ - add_device_softmax_i8_i8_rank3_reduce1_instances(instances); - add_device_softmax_i8_i8_rank3_reduce2_instances(instances); - add_device_softmax_i8_i8_rank3_reduce3_instances(instances); -} - -void add_device_softmax_i8_i8_rank4_instances( - std::vector>& instances) -{ - add_device_softmax_i8_i8_rank4_reduce1_instances(instances); - add_device_softmax_i8_i8_rank4_reduce2_instances(instances); - add_device_softmax_i8_i8_rank4_reduce3_instances(instances); - add_device_softmax_i8_i8_rank4_reduce4_instances(instances); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.cpp deleted file mode 100644 index 15552dbae5d501c506f4fd14b439a80664a95b66..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 3; - -void add_device_softmax_i8_i8_rank3_reduce1_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.cpp deleted file mode 100644 index 67674028860b471888bc87a59e1b7ae751c64eec..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 3; - -void add_device_softmax_i8_i8_rank3_reduce2_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.cpp deleted file mode 100644 index 4b33da93c2e1330485a55f8ede6654ab2a03fbb1..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank3_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 3; - -void add_device_softmax_i8_i8_rank3_reduce3_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.cpp deleted file mode 100644 index fe3b823e889267f69001e15066bbef201afa9813..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce1.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 4; - -void add_device_softmax_i8_i8_rank4_reduce1_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.cpp deleted file mode 100644 index 8ecdf87d9fec061094f83c3beba01497dcb8e5b8..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce2.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 4; - -void add_device_softmax_i8_i8_rank4_reduce2_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.cpp deleted file mode 100644 index 3563135204085ff8e44d1378f9c4d9ffd99e7b68..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce3.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 4; - -void add_device_softmax_i8_i8_rank4_reduce3_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.cpp b/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.cpp deleted file mode 100644 index aa21a0bf8a863f433665e01e8229bd179948e1db..0000000000000000000000000000000000000000 --- a/library/src/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_rank4_reduce4.hpp" -#include "ck/library/tensor_operation_instance/gpu/softmax/device_softmax_i8_i8_instance_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -static constexpr index_t RANK = 4; - -void add_device_softmax_i8_i8_rank4_reduce4_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_softmax_i8_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/utility/convolution_parameter.cpp b/library/src/utility/convolution_parameter.cpp index c8712d20939d5f9f1b7b981645e77420c7a5c607..57cedd60199e1c3aa291d44c404d56e89f444e9a 100644 --- a/library/src/utility/convolution_parameter.cpp +++ b/library/src/utility/convolution_parameter.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host_utility/io.hpp" diff --git a/library/src/utility/device_memory.cpp b/library/src/utility/device_memory.cpp index 90f943313b0961bb96d6394e855809c62d050559..61b6326b57da736f5565d2d22e34be9b7fcdec67 100644 --- a/library/src/utility/device_memory.cpp +++ b/library/src/utility/device_memory.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/host_utility/hip_check_error.hpp" @@ -10,20 +10,67 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); } +void DeviceMem::Realloc(std::size_t mem_size) +{ + if(mpDeviceBuf) + { + hip_check_error(hipFree(mpDeviceBuf)); + } + mMemSize = mem_size; + hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); +} + void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; } std::size_t DeviceMem::GetBufferSize() const { return mMemSize; } void DeviceMem::ToDevice(const void* p) const { - hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); + if(mpDeviceBuf) + { + hip_check_error( + hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); + } + else + { + throw std::runtime_error("ToDevice with an empty pointer"); + } +} + +void DeviceMem::ToDevice(const void* p, const std::size_t cpySize) const +{ + hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); } void DeviceMem::FromDevice(void* p) const { - hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); + if(mpDeviceBuf) + { + hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); + } + else + { + throw std::runtime_error("FromDevice with an empty pointer"); + } } -void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); } +void DeviceMem::FromDevice(void* p, const std::size_t cpySize) const +{ + hip_check_error(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); +} -DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); } +void DeviceMem::SetZero() const +{ + if(mpDeviceBuf) + { + hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); + } +} + +DeviceMem::~DeviceMem() +{ + if(mpDeviceBuf) + { + hip_check_error(hipFree(mpDeviceBuf)); + } +} diff --git a/library/src/utility/host_tensor.cpp b/library/src/utility/host_tensor.cpp index e34fbc8f345b8ca6dd5b34988a1839c6c59e61bd..7211552641195d68e42ff606ddc1b2c7aab7ba84 100644 --- a/library/src/utility/host_tensor.cpp +++ b/library/src/utility/host_tensor.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/profiler/README.md b/profiler/README.md index 400a64a3965c956cfd713e6a2273617b12ddfa1e..d03bfa7fc4033c6af1d48d96f0ffb28200187e79 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -76,3 +76,149 @@ e_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1} .... Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s ``` + +## Profile batched gemm multiple D kernels +```bash +#arg1: tensor operation (batched_gemm_multi_d=Batched GEMM multi D); +#arg2: data type (0: fp16; 1: int8) +#arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n]; +# 1: A[g, m, k] * B[g, n, k] = C[g, m, n]; +# 2: A[g, k, m] * B[g, k, n] = C[g, m, n]; +# 3: A[g, k, m] * B[g, n, k] = C[g, m, n]) +#arg4: verification (0: no; 1: yes) +#arg5: initialization (0: no init; 1: integer value; 2: decimal value) +#arg6: print tensor value (0: no; 1: yes) +#arg7: time kernel (0=n0, 1=yes) +#arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount + +################ op datatype layout verify init log time M N K StrideA StrideB StrideC BatchStrideA BatchStrideB BatchStrideC BatchCount +./bin/ckProfiler batched_gemm_multi_d 0 1 0 0 0 1 4096 4096 4096 4096 4096 4096 16777216 16777216 16777216 16 +``` + +Result (Radeon RX 6800 XT) +```bash +arg.a_grid_desc_k0_m0_m1_k1_{2048, 4096, 2} +arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} +arg.e_grid_desc_m_n_{ 4096, 4096} +.... +Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s +## Profile grouped convolution backward data kernels +```bash +# arg1: tensor operation (grouped_conv_bwd_data: Grouped Convolution Backward Data) +# arg2: data type (0: Output fp32, Weight fp32, Input fp32 +# 1: Output fp16, Weight fp16, Input fp16 +# 2: Output bf16, Weight bf16, Input bf16 +# arg3: tensor layout (0: Output[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Input[G, N, Ho, Wo, K] +# 1: Output[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Input[N, Ho, Wo, G, K]) +# arg4: verification (0: no, 1: yes) +# arg5: initialization (0: no init, 1: integer value, 2: decimal value) +# arg6: print tensor value (0: no; 1: yes) +# arg7: time kernel (0: no, 1: yes) +# Following arguments (depending on number of spatial dims): +# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) +# G, N, K, C, +# , (ie Y, X for 2D) +# , (ie Hi, Wi for 2D) +# , (ie Sy, Sx for 2D) +# , (ie Dy, Dx for 2D) +# , (ie LeftPy, LeftPx for 2D) +# , (ie RightPy, RightPx for 2D) + + ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx +./bin/ckProfiler grouped_conv_bwd_data 1 0 1 1 0 1 2 32 4 192 192 3 3 28 28 1 1 1 1 1 1 1 1 + + ``` + +Result (MI100, FP16, GNHWC_GKYXC_GNHWK) +``` +out: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} +wei: dim 5, lengths {32, 192, 192, 3, 3}, strides {331776, 1728, 1, 576, 192} +in: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} +.... +Best configuration parameters: +name: DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<256, 128, 256, 32, 8, 2, Default, 32, 32, 2, 4, 8, 4, 1, 1> +avg_time: 0.768321 +tflops: 86.6679 +GB/s: 127.947 +``` + +## Profile grouped convolution backward weight kernels +```bash +# arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight) +# arg2: data type (0: Input fp32, Weight fp32, Output fp32 +# 1: Input fp16, Weight fp16, Output fp16 +# 2: Input bf16, Weight fp32, Output bf16) +# arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, N, K, Ho, Wo] +# 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K] +# 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K] +# arg4: verification (0: no, 1: yes) +# arg5: initialization (0: no init, 1: integer value, 2: decimal value) +# arg6: print tensor value (0: no; 1: yes) +# arg7: time kernel (0: no, 1: yes) +# Following arguments (depending on number of spatial dims): +# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) +# G, N, K, C, +# , (ie Y, X for 2D) +# , (ie Hi, Wi for 2D) +# , (ie Sy, Sx for 2D) +# , (ie Dy, Dx for 2D) +# , (ie LeftPy, LeftPx for 2D) +# , (ie RightPy, RightPx for 2D) +# SplitK + + ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK +./bin/ckProfiler grouped_conv_bwd_weight 1 0 1 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1 + + ``` + +Result (MI100, FP16, GNHWC_GKYXC_GNHWK) +``` +input: dim 5, lengths {32, 512, 1024, 28, 28}, strides {411041792, 802816, 1, 28672, 1024} +weight: dim 5, lengths {32, 512, 1024, 3, 3}, strides {4718592, 9216, 1, 3072, 1024} +output: dim 5, lengths {32, 512, 512, 26, 26}, strides {177209344, 346112, 1, 13312, 512} +.... +Best configuration parameters: +name: DeviceGroupedConvBwdWeight_Xdl_CShuffle<256, 256, 128, 4, Default, 8, 4, 2, 8, 4, 8, 2, 1, 1, 8> +avg_time: 68.5216 +tflops: 95.337 +GB/s: 69.2301 +``` +Note: This kernel use atomic add, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time. + +## Profile image to column kernels +```bash +# arg1: tensor operation (" OP_NAME ": " OP_DESC ") +# arg2: data type (0: Input fp32, Weight fp32, Output fp32 +# 1: Input fp16, Weight fp16, Output fp16 +# 2: Input bf16, Weight bf16, Output bf16 +# 3: Input int8, Weight int8, Output int8) +# arg3: tensor layout (0: Input[N, Hi, Wi, C], Output[N * Ho * Wo, Y * X * C]) +# arg4: verification (0: no, 1: yes) +# arg5: initialization (0: no init, 1: integer value, 2: decimal value) +# arg6: print tensor value (0: no; 1: yes) +# arg7: time kernel (0: no, 1: yes) +# Following arguments (depending on number of spatial dims): +# Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) +# G, N, K, C, +# , (ie Y, X for 2D) +# , (ie Hi, Wi for 2D) +# , (ie Sy, Sx for 2D) +# , (ie Dy, Dx for 2D) +# , (ie LeftPy, LeftPx for 2D) +# , (ie RightPy, RightPx for 2D) + + ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx +./bin/ckProfiler image_to_column 0 0 1 1 0 1 2 1 256 1 512 3 3 28 28 1 1 1 1 0 0 0 0 + + ``` + +Result (MI210, FP32, NHWC) +``` +input: dim 5, lengths {1, 256, 512, 28, 28}, strides {102760448, 401408, 1, 14336, 512} +output: dim 2, lengths {173056, 4608}, strides {4608, 1} +.... +Best configuration parameters: +name: DeviceImageToColumn<128, 32, 64, 4> +avg_time: 3.12326 +GB/s: 2042.59 +``` diff --git a/profiler/include/profiler/data_type_enum.hpp b/profiler/include/profiler/data_type_enum.hpp index afcd6fea224f3e06f29309349bcc6ad55caa7d4f..c046c7fabb30e831fe41b31f936febef0f5bf148 100644 --- a/profiler/include/profiler/data_type_enum.hpp +++ b/profiler/include/profiler/data_type_enum.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp b/profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e7e8f7213f4eeb1d044f21c8d39f85613f290289 --- /dev/null +++ b/profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" +#include "ck/library/tensor_operation_instance/gpu/avg_pool3d_bwd.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp" + +namespace ck { +namespace profiler { + +template +std::vector f_tensor_strides_ncdhw(ck::index_t N_, + ck::index_t C_, + ck::index_t D, + ck::index_t H, + ck::index_t W, + TensorLayout layout) +{ + using namespace ck::literals; + (void)N_; + if constexpr(ck::is_same::value) + return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; + else + throw std::runtime_error("not supported yet"); +}; + +template +bool profile_avg_pool3d_bwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector in_length, // NCDHW + std::vector window_spatial_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + constexpr index_t InOutRank = 5; + constexpr index_t WindowRank = 3; + + if(in_length.size() != InOutRank || window_spatial_lengths.size() != WindowRank || + window_strides.size() != WindowRank || window_dilations.size() != WindowRank || + input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) + { + std::cout << "Parameter is incorrect" << std::endl; + return false; + } + + std::vector out_length(InOutRank); + + int N = in_length[0]; + int C = in_length[1]; + + out_length[0] = N; + out_length[1] = C; + + // Calculate Do, Ho, Wo + for(int i = 2; i < InOutRank; ++i) + { + auto pad1 = input_left_pads[i - 2]; + auto pad2 = input_right_pads[i - 2]; + auto windows_size = window_spatial_lengths[i - 2]; + auto windows_stride = window_strides[i - 2]; + auto windows_dilation = window_dilations[i - 2]; + auto eff = (windows_size - 1) * windows_dilation + 1; + out_length[i] = (in_length[i] + pad1 + pad2 - eff) / windows_stride + 1; + } + + int Di = in_length[2]; + int Hi = in_length[3]; + int Wi = in_length[4]; + int Do = out_length[2]; + int Ho = out_length[3]; + int Wo = out_length[4]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t D, std::size_t H, std::size_t W) { + using namespace ck::literals; + + return HostTensorDescriptor({N_, C_, D, H, W}, + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + }; + + Tensor dout_n_c_do_ho_wo(f_host_tensor_descriptor(N, C, Do, Ho, Wo)); + Tensor din_n_c_di_hi_wi_device(f_host_tensor_descriptor(N, C, Di, Hi, Wi)); + Tensor din_n_c_di_hi_wi_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi)); + + switch(init_method) + { + case 0: dout_n_c_do_ho_wo.GenerateTensorValue(GeneratorTensor_1{}); break; + case 1: dout_n_c_do_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: dout_n_c_do_ho_wo.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem dout_device_buf(sizeof(DOutDataType) * dout_n_c_do_ho_wo.mDesc.GetElementSpaceSize()); + DeviceMem din_device_buf(sizeof(DInDataType) * + din_n_c_di_hi_wi_device.mDesc.GetElementSpaceSize()); + + dout_device_buf.ToDevice(dout_n_c_do_ho_wo.mData.data()); + + using DeviceOp = ck::tensor_operation::device:: + DeviceAvgPoolBwd<3, DOutDataType, DInDataType, DOutLayout, DInLayout>; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferencePoolingBwdInstance = + ck::tensor_operation::host::ReferenceAvgPoolBwd<3, DInDataType, DOutDataType>; + + ReferencePoolingBwdInstance ref_pooling_bwd; + auto ref_pooling_bwd_argument = ref_pooling_bwd.MakeArgument(din_n_c_di_hi_wi_host, + dout_n_c_do_ho_wo, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + auto ref_invoker = ref_pooling_bwd.MakeInvoker(); + ref_invoker.Run(ref_pooling_bwd_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + {N, C, Do, Ho, Wo}, + {N, C, Di, Hi, Wi}, + f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, DOutLayout{}), + f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, DInLayout{}), + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "doutput lengths = ", out_length, ", ") << std::endl; + } + + continue; + } + + din_device_buf.SetZero(); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = + dout_n_c_do_ho_wo.mDesc.GetElementSize() * sizeof(DOutDataType) + + din_n_c_di_hi_wi_device.mDesc.GetElementSize() * sizeof(DInDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + din_device_buf.FromDevice(din_n_c_di_hi_wi_device.mData.data()); + bool pass = ck::utils::check_err(din_n_c_di_hi_wi_device.mData, + din_n_c_di_hi_wi_host.mData, + "Error: Incorrect results", + 1e-3, + 1e-3); + + if(do_log) + { + LogRangeAsType( + std::cout << "din_n_c_di_hi_wi_device: ", din_n_c_di_hi_wi_device.mData, ",") + << std::endl; + + LogRangeAsType( + std::cout << "din_n_c_di_hi_wi_host: ", din_n_c_di_hi_wi_host.mData, ",") + << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "doutput lengths = [", out_length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", out_length, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp b/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp index b16254279ce433d06f71ca350b53841a294a3de6..22dab31100d469938d39de0db89bf5b2257383c8 100644 --- a/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp index 799dccc0ff3a8ef253a6153dbdc64d8f7a5c26cb..5bee67c1ce9dc619f2b5208bf350a8fc01dceb06 100644 --- a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp index 1583c6db21e0959bd155946b2e81363091b4f43f..f3d2c5561756d43a4af44bbde12cf93c26f5602c 100644 --- a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_impl.hpp index c07d7c0555490559bbf0357a5d2657bcbe5261fd..936c22f5d89143198a1d4a01cb06bd5c66ca043e 100644 --- a/profiler/include/profiler/profile_batched_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,9 +8,11 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -27,7 +29,11 @@ template + typename CLayout, + typename AElementOp, + typename BElementOp, + typename CElementOp, + typename DeviceOp> bool profile_batched_gemm_impl(int do_verification, int init_method, bool do_log, @@ -88,10 +94,6 @@ bool profile_batched_gemm_impl(int do_verification, b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - const auto a_element_op = AElementOp{}; const auto b_element_op = BElementOp{}; const auto c_element_op = CElementOp{}; @@ -124,16 +126,6 @@ bool profile_batched_gemm_impl(int do_verification, b_device_buf.ToDevice(b_g_k_n.mData.data()); c_device_buf.ToDevice(c_g_m_n_device_result.mData.data()); - using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm; - // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); @@ -148,23 +140,62 @@ bool profile_batched_gemm_impl(int do_verification, // profile device op instances for(auto& op_ptr : op_ptrs) { - auto argument_ptr = - op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - BatchStrideA, - BatchStrideB, - BatchStrideC, - BatchCount, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}); + std::unique_ptr argument_ptr; + // false branch for multi d dl kernel + if constexpr(std::is_same< + DeviceOp, + ck::tensor_operation::device::DeviceBatchedGemm>::value) + { + + argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchCount, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + } + else + { + argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + {}, + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + BatchCount, + StrideA, + StrideB, + {}, + StrideC, + BatchStrideA, + BatchStrideB, + {}, + BatchStrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + } auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp index 45b7b77388b63c8abba43aabc50a9c0dd86eeede..901fa338d4d7aab26a70da548e2b000c98840d26 100644 --- a/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp index f5ec235141a7238157c0d2fd34f1f9c849decfc4..15a21206c5945d700069cd7a8f8a9fdff7c4e9bf 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp index 91c28f25fc5953460a00d87961b97d1587aad1ce..f2fcb0b133861b87a92e3434c8c3e22b0ef1a51c 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batchnorm_backward_impl.hpp b/profiler/include/profiler/profile_batchnorm_backward_impl.hpp index 79d8862081fd78b0706110219dc5af5016cd9122..3343b5e66e4b8666edf778abf4393dd75cbd0937 100644 --- a/profiler/include/profiler/profile_batchnorm_backward_impl.hpp +++ b/profiler/include/profiler/profile_batchnorm_backward_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batchnorm_forward_impl.hpp b/profiler/include/profiler/profile_batchnorm_forward_impl.hpp index 82fe75bf015beba4d56898f67aa2100b2d3c6ff4..2f9538b16c326a740b1e3bd0c3c150eaaadb9ebf 100644 --- a/profiler/include/profiler/profile_batchnorm_forward_impl.hpp +++ b/profiler/include/profiler/profile_batchnorm_forward_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_batchnorm_infer_impl.hpp b/profiler/include/profiler/profile_batchnorm_infer_impl.hpp index ca653393452c8b4fb880fee5041a72a497c9566b..1b31a2aabf5f42a115d83270f91d87f0f74d8a4d 100644 --- a/profiler/include/profiler/profile_batchnorm_infer_impl.hpp +++ b/profiler/include/profiler/profile_batchnorm_infer_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index 86d394daf90c79d6f9faf34254901dc8e2bd43e2..52152a90fe5eab2fd80ab156d9080c75b65d9eb0 100644 --- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp index 1aebef8bb2b2b34a640797ab9e5985d5cb09fbea..436fbdbd759195add8fba952235392d31b9613a0 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_add_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp index 2bac144334eefc2d39ff63e17637b50d977921ec..808c1a1c901293816bf3fb153c8f8a7d2faaaf15 100644 --- a/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_bias_relu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_conv_fwd_impl.hpp b/profiler/include/profiler/profile_conv_fwd_impl.hpp index 1f3ba8f00714449e51f180879f3b38d01221cefe..bc2eb257970d1b91ab7fbf1a592d18df0ef3d159 100644 --- a/profiler/include/profiler/profile_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp index 7707e16b089ef3058dea988188f80e453fcd4e7d..1fd9c811095a823ce2746ddba63431ccdd77f20b 100644 --- a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp b/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp index 3cc2ea3b92624da3db43696c181248e55a5f0256..81b8d8ddbf551f23de68beabd8900b4251c97c2a 100644 --- a/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_add_fastgelu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp b/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp index d53a6589e0f505ed3beb6a7792c9045c49ee0c68..6f6d881c1e4ef5214fb13d3b7523c2b4c979fd49 100644 --- a/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_fastgelu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp b/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp index 40093e774f0bcd770e3bc1434566d5d36c069f89..25871dfb2ec814fb5d0cda406a31ad8602b9dc16 100644 --- a/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_multiply_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp index e1c90f0f5255de2386b252acba5684910edeebf6..4c3d0a045054841b3f88a9c85ba9eaa8fd2538b7 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp index b4ec78cdf37659b6e838b10a8c235fdd6a0c9c64..c0ffea8a326c8bbc95a9b2d2e46da651089d807a 100644 --- a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_bilinear_impl.hpp b/profiler/include/profiler/profile_gemm_bilinear_impl.hpp index 31bae281c45b2e4e9bbdd8068fd3b084270c45ad..b540e938b5b5ab5798a6a18f86e452caf819c75e 100644 --- a/profiler/include/profiler/profile_gemm_bilinear_impl.hpp +++ b/profiler/include/profiler/profile_gemm_bilinear_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp b/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp index f9a544c044f4165fbee8cab9ba10a10616f27e22..3893f8cdc7cfbf957fe28cd2b019092b59c88029 100644 --- a/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index 9b164104b505d3d3912206b5fa1a6ee6aeb2ef48..eaab5dbcc2c8d3a6d5e1ea5d59c1c292f696d16d 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_multiply_add_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_add_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5fc92d1f8dec63dc80385e7bd241367b0df672d0 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_multiply_add_impl.hpp @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_multiply_add_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = MultiplyAdd; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = + ck::tensor_operation::device::DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + CDEElementOp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + d1_m_n_device_buf.ToDevice(d1_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_reduce_impl.hpp index 370121a3ccff382539ee34994a05b4b035db5512..ff801e8afd3731e62bd6edb545cce87bbddeae74 100644 --- a/profiler/include/profiler/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index ab1bce258a22a0c6b686f8559c3c68a592628fa2..fb68bb8811b58209f06532d86b08696b1051153c 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -94,7 +94,6 @@ bool profile_gemm_splitk_impl(int do_verification, a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); - c_device_buf.SetZero(); using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitKMakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - KBatch); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) + std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 36, 40, 60, + 64, 72, 80, 88, 96, 128, 144, 160, 176, 192, 256}; + + if(KBatch > 0) { - // re-init C to zero before profiling next kernel - c_device_buf.SetZero(); + kbatch_list = {KBatch}; + } + + for(std::size_t i = 0; i < kbatch_list.size(); i++) + { + auto kbatch_curr = kbatch_list[i]; + + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + kbatch_curr); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { - std::string op_name = op_ptr->GetTypeString(); + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - std::size_t flop = std::size_t(2) * M * N * K; + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + std::string op_name = op_ptr->GetTypeString(); - float tflops = static_cast(flop) / 1.E9 / ave_time; + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - float gb_per_sec = num_btype / 1.E6 / ave_time; + std::size_t flop = std::size_t(2) * M * N * K; - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N; - if(tflops > best_tflops) - { - best_op_name = op_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } + float tflops = static_cast(flop) / 1.E9 / ave_time; - if(do_verification) - { - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + float gb_per_sec = num_btype / 1.E6 / ave_time; - pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " + << kbatch_curr << std::endl; - if(do_log) + // set softer tolerances for fp8 + if constexpr(is_same_v || is_same_v || + is_same_v) { - LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") - << std::endl; + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-1; + double atol = 1e-1; + pass = pass & ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); + } + else + { + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + } + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; } } - } - else - { - std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } } } @@ -246,7 +282,7 @@ bool profile_gemm_splitk_impl(int do_verification, } std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA - << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << KBatch + << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; diff --git a/profiler/include/profiler/profile_gemm_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_streamk_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..71b54c1f47c3853f33aaf1f21ac5e7239d691157 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_streamk_impl.hpp @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_streamk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + uint32_t NumSKBlocks = 0xffffffff) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-3, 3}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmStreamK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances, " + << (do_verification ? "with verification" : "without verification") << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + NumSKBlocks); + DeviceMem workspace; + std::size_t workspace_size = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + if(workspace_size != 0) + { + workspace.Realloc(workspace_size); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + } + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..93d3430bba30448b4cd1e62d402bc251e1116191 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_conv_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto out_element_op = OutElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto in_element_op = InElementOp{}; + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + Tensor out(out_g_n_k_wos_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor in_host(in_g_n_c_wis_desc); + Tensor in_device(in_g_n_c_wis_desc); + + std::cout << "out: " << out.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "in: " << in_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_1{1}); + wei.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + // reset input to zero + in_device_buf.SetZero(); + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + in_host.SetZero(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + out_element_op, + wei_element_op, + in_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + bool pass = true; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + in_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + in_device_buf.FromDevice(in_device.mData.data()); + + pass = pass & ck::utils::check_err(in_device, in_host); + + if(do_log) + { + LogRangeAsType(std::cout << "output : ", out.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in_host : ", in_host.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "in_device: ", in_device.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + // do GEMM + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + OutElementOp, + WeiElementOp, + InElementOp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::array out_lengths{}; + std::array out_strides{}; + std::array wei_lengths{}; + std::array wei_strides{}; + std::array in_lengths{}; + std::array in_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), out_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), out_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides); + copy(in_g_n_c_wis_desc.GetLengths(), in_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), in_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + {}, + static_cast(in_device_buf.GetDeviceBuffer()), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op); + + run_impl(op_ptr, argument_ptr); + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 4f9aa98376c9e56058f04938f8171ac93c230e3e..48bf639a7090c777c649982e4601eda391a07db2 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -136,9 +136,12 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, // profile device Conv instances bool all_pass = true; - std::array input_spatial_lengths{}; - std::array filter_spatial_lengths{}; - std::array output_spatial_lengths{}; + std::array input_lengths{}; + std::array filter_lengths{}; + std::array output_lengths{}; + std::array input_strides{}; + std::array weights_strides{}; + std::array output_strides{}; std::array conv_filter_strides{}; std::array conv_filter_dilations{}; std::array input_left_pads{}; @@ -146,9 +149,12 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; - range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); - range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); - range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); + range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths)); + range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); + range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths)); + range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides)); + range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths)); + range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.input_left_pads_, begin(input_left_pads)); @@ -160,13 +166,12 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.G_, - conv_param.N_, - conv_param.K_, - conv_param.C_, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index b201a2ed331cbf4a0e3951d8d69027f018af09f6..8d37c28881cd0c5edc9af45fb665a8f4bfcc7d5e 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -215,7 +215,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification, const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); - std::cout << "xdl found " << op_ptrs.size() << " instances" << std::endl; + std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; for(auto& op_ptr : op_ptrs) { diff --git a/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp index 87e6ae44c76e56919fb17ad232586ae6294f1b68..f05b13b7495dfefee263f51d5404dda3dbc598dc 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 9abb5e7a53c860a2876f1c2a8c831a8963a4939a..fe7a3976064c8c4241597f48f8b8e4d17769cf21 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -70,6 +70,7 @@ bool profile_grouped_gemm_impl(int do_verification, std::vector> a_m_k; std::vector> b_k_n; + std::vector> c_m_n_host_results; std::vector> c_m_n_device_results; for(std::size_t i = 0; i < group_count; i++) @@ -81,6 +82,9 @@ bool profile_grouped_gemm_impl(int do_verification, c_m_n_device_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); + + c_m_n_host_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); #if DEBUG_LOG std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i @@ -137,7 +141,6 @@ bool profile_grouped_gemm_impl(int do_verification, a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); - c_device_buf[i]->SetZero(); gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); @@ -170,9 +173,36 @@ bool profile_grouped_gemm_impl(int do_verification, float best_ave_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; + float best_kbatch = 0; auto p_ds = std::vector>{}; + if(do_verification) + { + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + b_k_n[i], + c_m_n_host_results[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + } + // profile device GEMM instances for(auto& gemm_ptr : op_ptrs) { @@ -193,139 +223,135 @@ bool profile_grouped_gemm_impl(int do_verification, gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); std::string gemm_name = gemm_ptr->GetTypeString(); - if(kbatch > 1) + using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmSplitK, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + // skip non-splitk grouped_gemm + if(dynamic_cast(gemm_ptr.get()) == nullptr) { - using DeviceOpSplitK = - ck::tensor_operation::device::DeviceGroupedGemmSplitK, - CLayout, - ADataType, - BDataType, - ck::Tuple<>, - CDataType, - AElementOp, - BElementOp, - CElementOp>; - - if(dynamic_cast(gemm_ptr.get()) != nullptr) - { - dynamic_cast(gemm_ptr.get()) - ->SetKBatchSize(argument_ptr.get(), kbatch); - } + continue; } - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}; + + if(kbatch > 0) { + kbatch_list = {kbatch}; + } - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + for(std::size_t j = 0; j < kbatch_list.size(); j++) + { + + auto kbatch_curr = kbatch_list[j]; + + dynamic_cast(gemm_ptr.get()) + ->SetKBatchSize(argument_ptr.get(), kbatch_curr); - if(time_kernel) + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - std::size_t flop = 0, num_btype = 0; for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + c_device_buf[i]->SetZero(); - num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + - sizeof(BDataType) * Ks[i] * Ns[i] + - sizeof(CDataType) * Ms[i] * Ns[i]; - } + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - float tflops = static_cast(flop) / 1.E9 / ave_time; + if(do_verification) + { + bool instance_pass = true; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops - << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl; + c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + + if(std::is_same_v && kbatch_curr > 1) + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i], + "Error: Incorrect results!", + 0.06); + } + else + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i]); + } + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_results[i].mData, ",") + << std::endl; + } + } - if(tflops > best_tflops) - { - best_gemm_name = gemm_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - } + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; - if(do_verification) - { - bool instance_pass = true; - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { + pass = pass && instance_pass; + } - c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); - c_device_buf[i]->SetZero(); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - Tensor c_m_n_host_result( - f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); - - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemm; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], - b_k_n[i], - c_m_n_host_result, - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - if(std::is_same_v && kbatch > 1) - { - instance_pass = - instance_pass && ck::utils::check_err(c_m_n_device_results[i], - c_m_n_host_result, - "Error: Incorrect results!", - 0.06); - } - else + if(time_kernel) + { + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) { - instance_pass = - instance_pass && - ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result); + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; } - if(do_log) + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops) { - LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") - << std::endl; - LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl; - LogRangeAsType( - std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; } } - - std::cout << "Instance: " << gemm_name << " verification " - << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; - - pass = pass && instance_pass; } - } - else - { - std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" - << std::endl; + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } } } if(time_kernel) { std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch + << std::endl; } return pass; diff --git a/profiler/include/profiler/profile_groupnorm_impl.hpp b/profiler/include/profiler/profile_groupnorm_impl.hpp index 73343f6bec22db95ccaba37cb5c568018f630944..f88ba8453c7f9a86e85e9d617a3c6123ebd6b7d8 100644 --- a/profiler/include/profiler/profile_groupnorm_impl.hpp +++ b/profiler/include/profiler/profile_groupnorm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -139,6 +139,10 @@ bool profile_groupnorm_impl(int do_verification, continue; } + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); diff --git a/profiler/include/profiler/profile_image_to_column_impl.hpp b/profiler/include/profiler/profile_image_to_column_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cc929e922016cdfce4897985f9b61408617c63d0 --- /dev/null +++ b/profiler/include/profiler/profile_image_to_column_impl.hpp @@ -0,0 +1,200 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_image_to_column.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/image_to_column.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp" + +namespace ck { +namespace profiler { + +template +using S = ck::Sequence; + +template +bool profile_image_to_column_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + const ck::index_t NDoHoWo = + conv_param.N_ * + ck::accumulate_n( + conv_param.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const ck::index_t CZYX = + conv_param.C_ * + ck::accumulate_n( + conv_param.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto in_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + const auto out_desc = HostTensorDescriptor({NDoHoWo, CZYX}); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array input_g_n_c_wis_strides{}; + std::array output_m_k_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(conv_param.input_spatial_lengths_, input_spatial_lengths); + copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths); + copy(conv_param.output_spatial_lengths_, output_spatial_lengths); + copy(in_desc.GetStrides(), input_g_n_c_wis_strides); + copy(out_desc.GetStrides(), output_m_k_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + Tensor input(in_desc); + Tensor host_output(out_desc); + Tensor device_output(out_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InputDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutputDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + + // run reference op + if(do_verification) + { + auto ref_image_to_column = ck::tensor_operation::host:: + ReferenceImageToColumn{}; + + auto ref_invoker = ref_image_to_column.MakeInvoker(); + auto ref_argument = ref_image_to_column.MakeArgument(input, + host_output, + conv_param.filter_spatial_lengths_, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_); + + // init host output to zero + host_output.SetZero(); + + ref_invoker.Run(ref_argument); + } + + using DeviceOp = ck::tensor_operation::device:: + DeviceImageToColumn; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device op instances + bool pass = true; + bool is_supporting_instance = false; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param.N_, + conv_param.C_, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + input_g_n_c_wis_strides, + output_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + is_supporting_instance = true; + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + std::string op_name = op_ptr->GetTypeString(); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + std::size_t num_btype = + NDoHoWo * CZYX * (sizeof(OutputDataType) + sizeof(InputDataType)); + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(avg_time < best_avg_time) + { + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + pass = pass & ck::utils::check_err(device_output, host_output); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\nGB/s: " << best_gb_per_sec << std::endl; + + return is_supporting_instance && pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_layernorm_impl.hpp b/profiler/include/profiler/profile_layernorm_impl.hpp index 7dd90d079775ffe2da37d4e63e8391e1ee3709d0..f969646c2f68da7b502b4a64b08f87ef9e004304 100644 --- a/profiler/include/profiler/profile_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_layernorm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -155,6 +155,10 @@ bool profile_layernorm_impl(int do_verification, continue; } + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); diff --git a/profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp b/profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..15fb4e90347310b2fe6d63d08b0c83c0a1533d1c --- /dev/null +++ b/profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" +#include "ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_max_pool3d_bwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector in_length, // NCDHW + std::vector window_spatial_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + // AtomicAdd only support f32 for now. ComputeDataType must be float32 + using ComputeDataType = float; + + constexpr index_t InOutRank = 5; + constexpr index_t WindowRank = 3; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + if(in_length.size() != InOutRank || window_spatial_lengths.size() != WindowRank || + window_strides.size() != WindowRank || window_dilations.size() != WindowRank || + input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) + { + std::cout << "Parameter is incorrect" << std::endl; + return false; + } + + std::vector out_length(InOutRank); + + int N = in_length[0]; + int C = in_length[1]; + + out_length[0] = N; + out_length[1] = C; + + // Calculate Do, Ho, Wo + for(int i = 2; i < InOutRank; ++i) + { + auto pad1 = input_left_pads[i - 2]; + auto pad2 = input_right_pads[i - 2]; + auto windows_size = window_spatial_lengths[i - 2]; + auto windows_stride = window_strides[i - 2]; + auto windows_dilation = window_dilations[i - 2]; + auto eff = (windows_size - 1) * windows_dilation + 1; + out_length[i] = (in_length[i] + pad1 + pad2 - eff) / windows_stride + 1; + } + + int Di = in_length[2]; + int Hi = in_length[3]; + int Wi = in_length[4]; + int Do = out_length[2]; + int Ho = out_length[3]; + int Wo = out_length[4]; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t D, std::size_t H, std::size_t W) { + using namespace ck::literals; + + return HostTensorDescriptor({N_, C_, D, H, W}, + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + }; + + Tensor in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi)); + Tensor out_n_c_do_ho_wo(f_host_tensor_descriptor(N, C, Do, Ho, Wo)); + Tensor out_indices_n_c_do_ho_wo(f_host_tensor_descriptor(N, C, Do, Ho, Wo)); + Tensor dout_n_c_do_ho_wo(f_host_tensor_descriptor(N, C, Do, Ho, Wo)); + Tensor din_n_c_di_hi_wi_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi)); + + Tensor din_n_c_di_hi_wi_device(f_host_tensor_descriptor(N, C, Di, Hi, Wi)); + + switch(init_method) + { + case 0: + in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_1{}); + dout_n_c_do_ho_wo.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + dout_n_c_do_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + dout_n_c_do_ho_wo.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_do_ho_wo.mDesc.GetElementSpaceSize()); + DeviceMem dout_device_buf(sizeof(DOutDataType) * dout_n_c_do_ho_wo.mDesc.GetElementSpaceSize()); + DeviceMem din_device_buf(sizeof(DInDataType) * + din_n_c_di_hi_wi_device.mDesc.GetElementSpaceSize()); + + // Generate index data from forwarding + { + using ReferencePoolingFwdInstance = + ck::tensor_operation::host::ReferencePoolingFwd; + + ReferencePoolingFwdInstance ref_pooling_fwd; + auto ref_pooling_fwd_argument = ref_pooling_fwd.MakeArgument(in_n_c_di_hi_wi, + out_n_c_do_ho_wo, + out_indices_n_c_do_ho_wo, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + auto ref_pooling_fwd_invoker = ref_pooling_fwd.MakeInvoker(); + ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument); + } + + indices_device_buf.ToDevice(out_indices_n_c_do_ho_wo.mData.data()); + dout_device_buf.ToDevice(dout_n_c_do_ho_wo.mData.data()); + + using DeviceOp = + ck::tensor_operation::device::DeviceMaxPoolBwd; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferencePoolingBwdInstance = + ck::tensor_operation::host::ReferenceMaxPoolBwd; + + ReferencePoolingBwdInstance ref_pooling_bwd; + auto ref_pooling_bwd_argument = ref_pooling_bwd.MakeArgument( + dout_n_c_do_ho_wo, out_indices_n_c_do_ho_wo, din_n_c_di_hi_wi_host, PassThrough{}); + auto ref_invoker = ref_pooling_bwd.MakeInvoker(); + ref_invoker.Run(ref_pooling_bwd_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + dout_n_c_do_ho_wo.mDesc.GetElementSpaceSize(), + din_n_c_di_hi_wi_device.mDesc.GetElementSpaceSize(), + window_spatial_lengths, + window_strides, + window_dilations); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "doutput lengths = ", out_length, ", ") << std::endl; + } + + continue; + } + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_device_buf(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_device_buf.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = + dout_n_c_do_ho_wo.mDesc.GetElementSize() * sizeof(DOutDataType) + + out_indices_n_c_do_ho_wo.mDesc.GetElementSize() * sizeof(IndexDataType) + + din_n_c_di_hi_wi_device.mDesc.GetElementSize() * sizeof(DInDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + din_device_buf.FromDevice(din_n_c_di_hi_wi_device.mData.data()); + + bool pass = ck::utils::check_err(din_n_c_di_hi_wi_device.mData, + din_n_c_di_hi_wi_host.mData, + "Error: Incorrect results", + 1e-3, + 1e-3); + + if(do_log) + { + LogRangeAsType( + std::cout << "out_indices_n_c_do_ho_wo: ", out_indices_n_c_do_ho_wo.mData, ",") + << std::endl; + + LogRangeAsType( + std::cout << "din_n_c_di_hi_wi_device: ", din_n_c_di_hi_wi_device.mData, ",") + << std::endl; + + LogRangeAsType( + std::cout << "din_n_c_di_hi_wi_host: ", din_n_c_di_hi_wi_host.mData, ",") + << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "doutput lengths = [", out_length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", out_length, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_pool2d_fwd_impl.hpp b/profiler/include/profiler/profile_pool2d_fwd_impl.hpp deleted file mode 100644 index c313a00bebf0a35d82262f5c7b1e24f5f2c7c6ce..0000000000000000000000000000000000000000 --- a/profiler/include/profiler/profile_pool2d_fwd_impl.hpp +++ /dev/null @@ -1,264 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" - -namespace ck { -namespace profiler { - -template -bool profile_pool2d_fwd_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - std::vector in_length, // NCHW - std::vector window_spatial_lengths, - std::vector window_strides, - std::vector input_left_pads, - std::vector input_right_pads) -{ - constexpr index_t InOutRank = 4; - constexpr index_t WindowRank = 2; - - if(in_length.size() != InOutRank || window_spatial_lengths.size() != WindowRank || - window_strides.size() != WindowRank || input_left_pads.size() != WindowRank || - input_right_pads.size() != WindowRank) - return false; - - std::vector out_length(InOutRank); - - int N = in_length[0]; - int C = in_length[1]; - - out_length[0] = N; - out_length[1] = C; - - // Calculate Ho, Wo - for(int i = 2; i < InOutRank; ++i) - { - auto pad1 = input_left_pads[i - 2]; - auto pad2 = input_right_pads[i - 2]; - auto windows_size = window_spatial_lengths[i - 2]; - auto windows_stride = window_strides[i - 2]; - out_length[i] = (in_length[i] + pad1 + pad2 - windows_size) / windows_stride + 1; - } - - int Hi = in_length[2]; - int Wi = in_length[3]; - int Ho = out_length[2]; - int Wo = out_length[3]; - - auto f_host_tensor_descriptor = - [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { - using namespace ck::literals; - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); - }; - - Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi)); - Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); - Tensor out_indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); - - Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); - Tensor out_indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); - - switch(init_method) - { - case 0: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{}); break; - case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; - default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize()); - DeviceMem out_device_buf(sizeof(OutDataType) * - out_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); - DeviceMem out_indices_device_buf(sizeof(IndexDataType) * - out_indices_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); - - in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - - // add device normalization instances - using DeviceOp = ck::tensor_operation::device::DevicePoolFwd; - - // get device op instances - const auto instance_ptrs = - ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; - - std::string best_instance_name; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - - if(do_verification) - { - using ReferenceInstance = ck::tensor_operation::host::ReferencePoolingFwd; - - ReferenceInstance ref; - auto ref_argument = ref.MakeArgument(in_n_c_hi_wi, - out_n_c_ho_wo_host, - out_indices_n_c_ho_wo_host, - window_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); - auto ref_invoker = ref.MakeInvoker(); - ref_invoker.Run(ref_argument); - } - - int num_kernel = 0; - - for(auto& inst_ptr : instance_ptrs) - { - auto argument_ptr = inst_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(out_indices_device_buf.GetDeviceBuffer()), - in_length, - window_spatial_lengths, - out_length, - {C * Hi * Wi, 1, Wi * C, C}, - {C * Ho * Wo, 1, Wo * C, C}, - {C * Ho * Wo, 1, Wo * C, C}, - window_strides, - input_left_pads, - input_right_pads, - {2, 3}); - - if(inst_ptr->IsSupportedArgument(argument_ptr.get())) - { - ++num_kernel; - } - else - { - if(time_kernel) - { - std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; - LogRange(std::cout << "input lengths = ", in_length, ", ") << std::endl; - } - - continue; - } - - auto invoker_ptr = inst_ptr->MakeInvokerPointer(); - - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t num_bytes = in_n_c_hi_wi.mDesc.GetElementSize() * sizeof(InDataType) + - out_n_c_ho_wo_host.mDesc.GetElementSize() * sizeof(OutDataType); - - if constexpr(OutputIndex) - num_bytes += out_indices_n_c_ho_wo_host.mDesc.GetElementSize() * sizeof(IndexDataType); - - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - if(time_kernel) - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " - << inst_ptr->GetTypeString() << std::endl; - - if(avg_time < best_avg_time) - { - best_instance_name = inst_ptr->GetTypeString(); - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - } - - if(do_verification) - { - out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); - - bool pass = ck::utils::check_err(out_n_c_ho_wo_device.mData, - out_n_c_ho_wo_host.mData, - "Error: Incorrect results", - 1e-3, - 1e-3); - - if constexpr(OutputIndex) - { - out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); - - pass = pass && ck::utils::check_err(out_indices_n_c_ho_wo_device, - out_indices_n_c_ho_wo_host); - } - - if(do_log) - { - LogRangeAsType(std::cout << "in_n_c_hi_wi : ", in_n_c_hi_wi.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "out_n_c_ho_wo_host : ", out_n_c_ho_wo_host.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "out_n_c_ho_wo_device : ", out_n_c_ho_wo_device.mData, ",") - << std::endl; - - if constexpr(OutputIndex) - LogRangeAsType(std::cout << "out_indices_n_c_ho_wo_device : ", - out_indices_n_c_ho_wo_device.mData, - ",") - << std::endl; - } - - if(!pass) - { - std::cout << inst_ptr->GetTypeString() << " failed verification: "; - LogRange(std::cout << "lengths = [", in_length, ", ") << "]." << std::endl; - return false; - } - else - { - if(time_kernel) - std::cout << "pass" << std::endl; - } - } - } - - if(time_kernel) - { - LogRange(std::cout << "length = ", in_length, ",") << std::endl; - std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s, " - << best_instance_name << std::endl; - } - - if(num_kernel == 0) - { - std::cout << "Error: No kernel is applicable" << std::endl; - return false; - } - - return true; -} - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp index c9e4c193f89221d99b1119bcf5431a60d899dd23..02fde48d6e6bdb2e4c528801a3f32afcba7133e9 100644 --- a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp +++ b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -21,6 +21,8 @@ template @@ -31,6 +33,7 @@ bool profile_pool3d_fwd_impl(int do_verification, std::vector in_length, // NCDHW std::vector window_spatial_lengths, std::vector window_strides, + std::vector window_dilations, std::vector input_left_pads, std::vector input_right_pads) { @@ -38,8 +41,8 @@ bool profile_pool3d_fwd_impl(int do_verification, constexpr index_t WindowRank = 3; if(in_length.size() != InOutRank || window_spatial_lengths.size() != WindowRank || - window_strides.size() != WindowRank || input_left_pads.size() != WindowRank || - input_right_pads.size() != WindowRank) + window_strides.size() != WindowRank || window_dilations.size() != WindowRank || + input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) return false; std::vector out_length(InOutRank); @@ -53,11 +56,13 @@ bool profile_pool3d_fwd_impl(int do_verification, // Calculate Do, Ho, Wo for(int i = 2; i < InOutRank; ++i) { - auto pad1 = input_left_pads[i - 2]; - auto pad2 = input_right_pads[i - 2]; - auto windows_size = window_spatial_lengths[i - 2]; - auto windows_stride = window_strides[i - 2]; - out_length[i] = (in_length[i] + pad1 + pad2 - windows_size) / windows_stride + 1; + auto pad1 = input_left_pads[i - 2]; + auto pad2 = input_right_pads[i - 2]; + auto windows_size = window_spatial_lengths[i - 2]; + auto windows_stride = window_strides[i - 2]; + auto windows_dilation = window_dilations[i - 2]; + auto eff = (windows_size - 1) * windows_dilation + 1; + out_length[i] = (in_length[i] + pad1 + pad2 - eff) / windows_stride + 1; } int Di = in_length[2]; @@ -104,6 +109,8 @@ bool profile_pool3d_fwd_impl(int do_verification, InDataType, OutDataType, IndexDataType, + InLayout, + OutLayout, ReduceOpId, OutputIndex>; @@ -136,6 +143,7 @@ bool profile_pool3d_fwd_impl(int do_verification, out_indices_n_c_do_ho_wo_host, window_spatial_lengths, window_strides, + window_dilations, input_left_pads, input_right_pads); auto ref_invoker = ref.MakeInvoker(); @@ -157,6 +165,7 @@ bool profile_pool3d_fwd_impl(int do_verification, {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, window_strides, + window_dilations, input_left_pads, input_right_pads, {2, 3, 4}); diff --git a/profiler/include/profiler/profile_reduce_impl.hpp b/profiler/include/profiler/profile_reduce_impl.hpp index e6182002999b8c9b0838dab70d387d92d59e6cdc..b54aa65aef727b048458e0b63f638bad143e4e28 100644 --- a/profiler/include/profiler/profile_reduce_impl.hpp +++ b/profiler/include/profiler/profile_reduce_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_softmax_impl.hpp b/profiler/include/profiler/profile_softmax_impl.hpp index 96816f53bbb7da01435569cb7aa6c8e2ea099586..daaf565149784729fc66c4c6ef048df47c0c8dae 100644 --- a/profiler/include/profiler/profile_softmax_impl.hpp +++ b/profiler/include/profiler/profile_softmax_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -40,7 +40,11 @@ template <> std::string type_to_string() { return "int8"; } template <> std::string type_to_string() { return "int32"; } // clang-format on -template +template bool profile_softmax_impl(int do_verification, int init_method, bool do_log, @@ -54,7 +58,13 @@ bool profile_softmax_impl(int do_verification, if(Rank != in_length.size()) { throw std::runtime_error("Input tensor rank is different from template argument Rank!"); - } + }; + + if(NumReduceDim != reduce_dims.size()) + { + throw std::runtime_error( + "Input reduce_dims rank is different from template argument NumReduceDim!"); + }; Tensor in = in_strides.empty() ? Tensor(in_length) : Tensor(in_length, in_strides); @@ -92,8 +102,13 @@ bool profile_softmax_impl(int do_verification, // add device softmax instances using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using DeviceOp = tensor_operation::device:: - DeviceSoftmax; + using DeviceOp = tensor_operation::device::DeviceSoftmax; // get device op instances const auto instances = tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -112,13 +127,6 @@ bool profile_softmax_impl(int do_verification, for(auto& inst_ptr : instances) { - // Is this user's responsibility to check if problem mismatches kernel instance (ie. rank 3 - // problem to rank 4 kernel) other than invoking IsSupportedArgument()? - if(!(inst_ptr->GetNumReduceDim() == static_cast(reduce_dims.size()))) - { - continue; - } - auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths, in_tensor_strides, reduce_dims, diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index c9fccc2588222bd640ab646fa3a691cc3b0c1ead..7da7613f2628a7afa7d7454457d1a545f1c8181a 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -3,19 +3,12 @@ set(PROFILER_SOURCES profiler.cpp profile_gemm.cpp profile_gemm_splitk.cpp - profile_gemm_bilinear.cpp profile_gemm_bias_add_reduce.cpp - profile_gemm_add_add_fastgelu.cpp profile_gemm_add_multiply.cpp - profile_gemm_add_fastgelu.cpp - profile_gemm_add_relu_add_layernorm.cpp - profile_gemm_fastgelu.cpp + profile_gemm_multiply_add.cpp profile_gemm_reduce.cpp profile_batched_gemm.cpp - profile_batched_gemm_gemm.cpp - profile_batched_gemm_add_relu_gemm_add.cpp profile_batched_gemm_reduce.cpp - profile_grouped_gemm.cpp profile_conv_fwd.cpp profile_conv_fwd_bias_relu.cpp profile_conv_fwd_bias_relu_add.cpp @@ -25,16 +18,33 @@ set(PROFILER_SOURCES profile_reduce.cpp profile_groupnorm.cpp profile_layernorm.cpp - profile_avg_pool2d_fwd.cpp profile_max_pool3d_fwd.cpp + profile_avg_pool3d_bwd.cpp + profile_max_pool3d_bwd.cpp profile_softmax.cpp profile_batchnorm_fwd.cpp profile_batchnorm_bwd.cpp profile_batchnorm_infer.cpp - profile_grouped_gemm_fastgelu.cpp profile_contraction_bilinear.cpp profile_contraction_scale.cpp + profile_grouped_conv_bwd_data.cpp + profile_image_to_column.cpp ) +if(DL_KERNELS) + list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) +endif() set(PROFILER_EXECUTABLE ckProfiler) @@ -44,19 +54,12 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) @@ -73,9 +76,27 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool_fwd_instance) - +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) +if(DL_KERNELS) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) +endif() rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/profiler/src/profile_avg_pool2d_fwd.cpp b/profiler/src/profile_avg_pool2d_fwd.cpp deleted file mode 100644 index b92288096d83eda37c75efe1cdaceb1603ef6cb7..0000000000000000000000000000000000000000 --- a/profiler/src/profile_avg_pool2d_fwd.cpp +++ /dev/null @@ -1,141 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include - -#include "profiler/data_type_enum.hpp" -#include "profiler/profile_pool2d_fwd_impl.hpp" -#include "profiler_operation_registry.hpp" - -using ck::index_t; - -struct avgPoolFwdArgParser -{ - std::unordered_map> long_opts = { - {"length", {}}, {"wsize", {}}, {"wstride", {}}, {"pad1", {}}, {"pad2", {}}}; - - bool parse_opt(int argc, char* argv[], const std::string& key, int i) - { - if(std::string("--") + key == argv[i]) - { - int pos = i; - while(++i < argc && argv[i][0] != '-') {} - int end = i; - for(int j = pos + 1; j < end; j++) - { - long_opts[key].push_back(std::stoi(argv[j])); - } - return true; - } - return false; - } - - void operator()(int argc, char* argv[]) - { - for(auto& kv : long_opts) - { - for(int i = 1; i < argc; i++) - { - if(parse_opt(argc, argv, kv.first, i)) - break; - } - } - } -}; - -void print_help_avg_pool2d_fwd() -{ - std::cout << "arg1: data type (0: fp16; 1: fp32)\n" - << "arg2: verification (0: no; 1: yes)\n" - << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" - << "arg4: print tensor value (0: no; 1: yes)\n" - << "arg5: time kernel (0=no, 1=yes)\n" - << "--length: input tensor length for NDHW(e.g, --length 2 32 30 30) \n" - << "--wsize: window size for YX (e.g, --wsize 2 2) \n" - << "--wstride: window stride for HW (e.g, --wstride 2 2) \n" - << "--pad1: left side of padding in HW (e.g, --pad1 1 1) \n" - << "--pad2: right side of padding in HW (e.g, --pad2 1 1) \n" - << "eg: ckProfiler avg_pool2d_fwd 0 1 2 0 1 0 --length 2 32 30 30 --wsize 2 2 " - "--wstride 2 2 --pad1 1 1 --pad2 1 1" - << std::endl; -} - -int profile_avg_pool2d_fwd(int argc, char* argv[]) -{ - ck::DataTypeEnum data_type = ck::DataTypeEnum::Half; - bool do_verification = true; - int init_method = 0; - bool do_log = false; - bool time_kernel = true; - - std::vector in_length = {2, 32, 30, 30}; - std::vector wsize = {2, 2}; - std::vector wstride = {2, 2}; - std::vector pad1 = {1, 1}; - std::vector pad2 = {1, 1}; - - if(argc != 2 && argc != 25) - { - print_help_avg_pool2d_fwd(); - return 0; - } - else if(argc == 25) - { - data_type = static_cast(std::stoi(argv[2])); - do_verification = std::stoi(argv[3]); - init_method = std::stoi(argv[4]); - do_log = std::stoi(argv[5]); - time_kernel = std::stoi(argv[6]); - - // parse the long options - avgPoolFwdArgParser arg_parser; - arg_parser(argc, argv); - in_length = arg_parser.long_opts["length"]; - wsize = arg_parser.long_opts["wsize"]; - wstride = arg_parser.long_opts["wstride"]; - pad1 = arg_parser.long_opts["pad1"]; - pad2 = arg_parser.long_opts["pad2"]; - } - - using F16 = ck::half_t; - using F32 = float; - using I32 = int32_t; - constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; - - if(data_type == ck::DataTypeEnum::Half) - { - ck::profiler::profile_pool2d_fwd_impl( - do_verification, - init_method, - do_log, - time_kernel, - in_length, - wsize, - wstride, - pad1, - pad2); - } - else if(data_type == ck::DataTypeEnum::Float) - { - ck::profiler::profile_pool2d_fwd_impl( - do_verification, - init_method, - do_log, - time_kernel, - in_length, - wsize, - wstride, - pad1, - pad2); - } - else - { - throw std::runtime_error("not implemented yet"); - } - - return 0; -} - -REGISTER_PROFILER_OPERATION("avg_pool2d_fwd", "avg_pool2d fwd", profile_avg_pool2d_fwd); diff --git a/profiler/src/profile_avg_pool3d_bwd.cpp b/profiler/src/profile_avg_pool3d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0ff50a5292fc1856d28406310e802524e6607b81 --- /dev/null +++ b/profiler/src/profile_avg_pool3d_bwd.cpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_avg_pool3d_bwd_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct maxPoolbwdArgParser +{ + std::unordered_map> long_opts = {{"length", {}}, + {"wsize", {}}, + {"wstride", {}}, + {"wdilation", {}}, + {"pad1", {}}, + {"pad2", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_avg_pool3d_bwd() +{ + std::cout << "arg1: data type (0: fp16; 1: fp32; 5: bf16)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: input tensor length for NCDHW(e.g, --length 2 32 30 30 30) \n" + << "--wsize: window size for ZYX (e.g, --wsize 2 2 2) \n" + << "--wstride: window stride for DHW (e.g, --wstride 2 2 2) \n" + << "--wdilation: window dilation for DHW (e.g, --wdilation 1 1 1) \n" + << "--pad1: left side of padding in DHW (e.g, --pad1 1 1 1) \n" + << "--pad2: right side of padding in DHW (e.g, --pad2 1 1 1) \n" + << "eg: ckProfiler avg_pool3d_bwd 0 1 2 0 1 --length 2 32 30 30 30 --wsize 2 2 2 " + "--wstride 2 2 2 --wdilation 1 1 1 --pad1 1 1 1 --pad2 1 1 1" + << std::endl; +} + +int profile_avg_pool3d_bwd(int argc, char* argv[]) +{ + ck::DataTypeEnum data_type = ck::DataTypeEnum::Half; + bool do_verification = true; + int init_method = 0; + bool do_log = false; + bool time_kernel = true; + + std::vector in_length = {2, 32, 30, 30, 30}; + std::vector wsize = {2, 2, 2}; + std::vector wstride = {2, 2, 2}; + std::vector wdilation = {1, 1, 1}; + std::vector pad1 = {1, 1, 1}; + std::vector pad2 = {1, 1, 1}; + + if(argc != 2 && argc != 33) + { + print_help_avg_pool3d_bwd(); + return 0; + } + else if(argc == 33) + { + data_type = static_cast(std::stoi(argv[2])); + do_verification = std::stoi(argv[3]); + init_method = std::stoi(argv[4]); + do_log = std::stoi(argv[5]); + time_kernel = std::stoi(argv[6]); + + // parse the long options + maxPoolbwdArgParser arg_parser; + arg_parser(argc, argv); + in_length = arg_parser.long_opts["length"]; + wsize = arg_parser.long_opts["wsize"]; + wstride = arg_parser.long_opts["wstride"]; + wdilation = arg_parser.long_opts["wdilation"]; + pad1 = arg_parser.long_opts["pad1"]; + pad2 = arg_parser.long_opts["pad2"]; + } + +#ifdef CK_ENABLE_FP16 + using F16 = ck::half_t; +#endif +#ifdef CK_ENABLE_BF16 + using BF16 = ck::bhalf_t; +#endif +#ifdef CK_ENABLE_FP32 + using F32 = float; +#endif + using NDHWC = ck::tensor_layout::convolution::NDHWC; + + if(false) + ; +#ifdef CK_ENABLE_FP16 + else if(data_type == ck::DataTypeEnum::Half) + { + ck::profiler::profile_avg_pool3d_bwd_impl(do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif +#ifdef CK_ENABLE_BF16 + else if(data_type == ck::DataTypeEnum::BFloat16) + { + ck::profiler::profile_avg_pool3d_bwd_impl(do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif +#ifdef CK_ENABLE_FP32 + else if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_avg_pool3d_bwd_impl(do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("avg_pool3d_bwd", "max_pool bwd", profile_avg_pool3d_bwd); diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp index 907a373794f2d8a59a066b94a591c4890c5c4038..222532b7bbdeb36bfa9953fb9f174baef3b548ac 100644 --- a/profiler/src/profile_batched_gemm.cpp +++ b/profiler/src/profile_batched_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -10,6 +10,8 @@ #include "profiler/profile_batched_gemm_impl.hpp" #include "profiler_operation_registry.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + enum struct GemmMatrixLayout { MK_KN_MN, // 0 @@ -78,55 +80,72 @@ int profile_batched_gemm(int argc, char* argv[]) using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; - auto profile = [&](auto a_type, - auto b_type, - auto c_type, - auto a_layout, - auto b_layout, - auto c_layout) { - using ADataType = decltype(a_type); - using BDataType = decltype(b_type); - using CDataType = decltype(c_type); - - using ALayout = decltype(a_layout); - using BLayout = decltype(b_layout); - using CLayout = decltype(c_layout); - - const int DefaultStrideA = ck::is_same_v ? K : M; - const int DefaultStrideB = ck::is_same_v ? N : K; - const int DefaultStrideC = ck::is_same_v ? N : M; - - const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA; - const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB; - const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC; - - const int DefaultBatchStrideA = (ck::is_same_v ? M : K) * StrideA_; - const int DefaultBatchStrideB = (ck::is_same_v ? K : N) * StrideB_; - const int DefaultBatchStrideC = (ck::is_same_v ? M : N) * StrideC_; - - const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA; - const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB; - const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC; - - bool pass = ck::profiler:: - profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - BatchStrideA_, - BatchStrideB_, - BatchStrideC_, - StrideA_, - StrideB_, - StrideC_, - BatchCount); - - return pass ? 0 : 1; - }; + auto profile = + [&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA; + const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB; + const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? M : K) * StrideA_; + const int DefaultBatchStrideB = (ck::is_same_v ? K : N) * StrideB_; + const int DefaultBatchStrideC = (ck::is_same_v ? M : N) * StrideC_; + + const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA; + const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB; + const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC; + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm; + + bool pass = ck::profiler::profile_batched_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + BatchStrideA_, + BatchStrideB_, + BatchStrideC_, + StrideA_, + StrideB_, + StrideC_, + BatchCount); + + return pass ? 0 : 1; + }; if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { diff --git a/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp b/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp index f440a3094eb16b4eef1c5dee46cb08f6ab25d933..3d29c4b84aa5abfafbb220c46205413adb534770 100644 --- a/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp +++ b/profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_batched_gemm_gemm.cpp b/profiler/src/profile_batched_gemm_gemm.cpp index 6015c93be35268c1c87c129d0cc7d2c9bdf0584f..9a99874d1c7041f692133a6ef2e792bdf0dd8105 100644 --- a/profiler/src/profile_batched_gemm_gemm.cpp +++ b/profiler/src/profile_batched_gemm_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_batched_gemm_multi_d.cpp b/profiler/src/profile_batched_gemm_multi_d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7cd4636d98ef5d3e71ff33b32eaf60b94a85d72f --- /dev/null +++ b/profiler/src/profile_batched_gemm_multi_d.cpp @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_batched_gemm_impl.hpp" +#include "profiler_operation_registry.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F16_F16_F16, // 0 + INT8_INT8_INT8, // 1 +}; + +#define OP_NAME "batched_gemm_multi_d" +#define OP_DESC "Batched GEMM multi D" + +int profile_batched_gemm_multi_d(int argc, char* argv[]) +{ + if(argc != 18) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp16; 1: int8)\n"); + printf("arg3: matrix layout (0: A[g, m, k] * B[g, k, n] = C[g, m, n];\n"); + printf(" 1: A[g, m, k] * B[g, n, k] = C[g, m, n];\n"); + printf(" 2: A[g, k, m] * B[g, k, n] = C[g, m, n];\n"); + printf(" 3: A[g, k, m] * B[g, n, k] = C[g, m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const int BatchStrideA = std::stoi(argv[14]); + const int BatchStrideB = std::stoi(argv[15]); + const int BatchStrideC = std::stoi(argv[16]); + + const int BatchCount = std::stoi(argv[17]); + + using F16 = ck::half_t; +#ifdef CK_ENABLE_INT8 + using INT8 = int8_t; +#endif + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = + [&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using CDataType = decltype(c_type); + using DsDataType = ck::Tuple<>; + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + using DsLayout = ck::Tuple<>; + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + const int StrideA_ = (StrideA < 0) ? DefaultStrideA : StrideA; + const int StrideB_ = (StrideB < 0) ? DefaultStrideB : StrideB; + const int StrideC_ = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? M : K) * StrideA_; + const int DefaultBatchStrideB = (ck::is_same_v ? K : N) * StrideB_; + const int DefaultBatchStrideC = (ck::is_same_v ? M : N) * StrideC_; + + const int BatchStrideA_ = (BatchStrideA < 0) ? DefaultBatchStrideA : BatchStrideA; + const int BatchStrideB_ = (BatchStrideB < 0) ? DefaultBatchStrideB : BatchStrideB; + const int BatchStrideC_ = (BatchStrideC < 0) ? DefaultBatchStrideC : BatchStrideC; + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmMultiD; + + bool pass = ck::profiler::profile_batched_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + BatchStrideA_, + BatchStrideB_, + BatchStrideC_, + StrideA_, + StrideB_, + StrideC_, + BatchCount); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{}); + } +#ifdef CK_ENABLE_INT8 + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{}); + } +#endif + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_batched_gemm_multi_d); diff --git a/profiler/src/profile_batched_gemm_reduce.cpp b/profiler/src/profile_batched_gemm_reduce.cpp index 6b1dfc01427a1cd34eefc89584bb6f3ce8bf206f..9620d63cafbeedfb5398b553b5172418923bb8e0 100644 --- a/profiler/src/profile_batched_gemm_reduce.cpp +++ b/profiler/src/profile_batched_gemm_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_batchnorm_bwd.cpp b/profiler/src/profile_batchnorm_bwd.cpp index 44ce7350ff06eaad3eff7d93f60c022d1a4c2b60..1738d53dbe7c6d2581e15df8bae58f7c48733999 100644 --- a/profiler/src/profile_batchnorm_bwd.cpp +++ b/profiler/src/profile_batchnorm_bwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_batchnorm_fwd.cpp b/profiler/src/profile_batchnorm_fwd.cpp index 902a1fc423f98880b43e1551f3305854d1a0b9b8..507fb4b450ba2d7582346458e729d958908950c0 100644 --- a/profiler/src/profile_batchnorm_fwd.cpp +++ b/profiler/src/profile_batchnorm_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -148,7 +148,7 @@ int profile_batchnorm_forward(int argc, char* argv[]) { if(arg_parser.inLengths.size() == 4 && arg_parser.reduceDims.size() == 3) { - profile_batchnorm_forward_impl( + profile_batchnorm_forward_impl( arg_parser.do_verification, arg_parser.init_method, arg_parser.do_dumpout, diff --git a/profiler/src/profile_batchnorm_infer.cpp b/profiler/src/profile_batchnorm_infer.cpp index 92c16859c1aa59ef814fa87166d3d27c9b43a3a2..f1c19bc36e8aded070b3ba843d9fa7007f37a55e 100644 --- a/profiler/src/profile_batchnorm_infer.cpp +++ b/profiler/src/profile_batchnorm_infer.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_conv_bwd_data.cpp b/profiler/src/profile_conv_bwd_data.cpp index 9241ead738ef8a8a67d8682d743f6c9bd0faf640..e08a39aeb02ba3950c459ae4a2c301533d071bc5 100644 --- a/profiler/src/profile_conv_bwd_data.cpp +++ b/profiler/src/profile_conv_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -77,7 +77,9 @@ int profile_conv_bwd_data(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; +#ifdef CK_ENABLE_INT8 using INT8 = int8_t; +#endif using NWC = ck::tensor_layout::convolution::NWC; using NHWC = ck::tensor_layout::convolution::NHWC; @@ -138,10 +140,12 @@ int profile_conv_bwd_data(int argc, char* argv[]) { return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{}); } +#ifdef CK_ENABLE_INT8 else if(data_type == ConvDataType::INT8_INT8_INT8) { return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{}); } +#endif } else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK) { @@ -157,10 +161,12 @@ int profile_conv_bwd_data(int argc, char* argv[]) { return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{}); } +#ifdef CK_ENABLE_INT8 else if(data_type == ConvDataType::INT8_INT8_INT8) { return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{}); } +#endif } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK) { @@ -176,10 +182,12 @@ int profile_conv_bwd_data(int argc, char* argv[]) { return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{}); } +#ifdef CK_ENABLE_INT8 else if(data_type == ConvDataType::INT8_INT8_INT8) { return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{}); } +#endif } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_conv_fwd.cpp b/profiler/src/profile_conv_fwd.cpp index b57ee7fd94261f4437d1b62f0b6450ea74522db9..701999d8a9fcd3718fc8d62e9b77f4f49159d474 100644 --- a/profiler/src/profile_conv_fwd.cpp +++ b/profiler/src/profile_conv_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_conv_fwd_bias_relu.cpp b/profiler/src/profile_conv_fwd_bias_relu.cpp index b44007cde4742efcec188f0b620e2cf83295f5b6..31055ec1d1530ff5fe9c89e89957fdd6e83f7efc 100644 --- a/profiler/src/profile_conv_fwd_bias_relu.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_conv_fwd_bias_relu_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_add.cpp index 408dd02f78dd6241b5f0a67ae7aac4420c891943..8c2439a0c74be78b2d6ffcfdcce033c9ec0c7df5 100644 --- a/profiler/src/profile_conv_fwd_bias_relu_add.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index 61bae6ae70ebcf965618a100c3f553de73182118..9ca7fc4c88685b3f4537e6d992076f256ce2787a 100644 --- a/profiler/src/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -67,11 +67,15 @@ int profile_gemm(int argc, char* argv[]) const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); - using F32 = float; - using F16 = ck::half_t; - using BF16 = ck::bhalf_t; + using F32 = float; + using F16 = ck::half_t; +#ifdef CK_ENABLE_BF16 + using BF16 = ck::bhalf_t; +#endif +#ifdef CK_ENABLE_INT8 using INT8 = int8_t; using INT32 = int32_t; +#endif using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -117,7 +121,10 @@ int profile_gemm(int argc, char* argv[]) return pass ? 0 : 1; }; - if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + if(false) + ; +#ifdef CK_ENABLE_FP32 + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(Row{}, Row{}, Row{}, F32{}, F32{}, F32{}, F32{}); } @@ -133,6 +140,8 @@ int profile_gemm(int argc, char* argv[]) { return profile(Col{}, Col{}, Row{}, F32{}, F32{}, F32{}, F32{}); } +#endif +#ifdef CK_ENABLE_FP16 else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(Row{}, Row{}, Row{}, F16{}, F16{}, F32{}, F16{}); @@ -149,6 +158,8 @@ int profile_gemm(int argc, char* argv[]) { return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{}); } +#endif +#ifdef CK_ENABLE_BF16 else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); @@ -165,6 +176,8 @@ int profile_gemm(int argc, char* argv[]) { return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{}); } +#endif +#ifdef CK_ENABLE_INT8 else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); @@ -181,6 +194,7 @@ int profile_gemm(int argc, char* argv[]) { return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); } +#endif else { std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_gemm_add_add_fastgelu.cpp b/profiler/src/profile_gemm_add_add_fastgelu.cpp index c3c0fb7b67daf131ab53194bbf28a12aeedf62aa..8af3768a48aa9905700b55ab73e83777f20f8223 100644 --- a/profiler/src/profile_gemm_add_add_fastgelu.cpp +++ b/profiler/src/profile_gemm_add_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_add_fastgelu.cpp b/profiler/src/profile_gemm_add_fastgelu.cpp index 380b25a614c58d42209135218cca331de3bebe33..a09bb8340d3268ed642b1cd15f318f9458867c99 100644 --- a/profiler/src/profile_gemm_add_fastgelu.cpp +++ b/profiler/src/profile_gemm_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_add_multiply.cpp b/profiler/src/profile_gemm_add_multiply.cpp index 7d6fead402f22a92d3b2c12b8d612c22582504ee..560467c264f2bcc8e406ed02de742c06942e27e0 100644 --- a/profiler/src/profile_gemm_add_multiply.cpp +++ b/profiler/src/profile_gemm_add_multiply.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_add_relu_add_layernorm.cpp b/profiler/src/profile_gemm_add_relu_add_layernorm.cpp index 5cbc3d21f8a335aaaf0353308a2fc996e210b520..558d255ce110565e90f6eff5fae86ce6bf6447bb 100644 --- a/profiler/src/profile_gemm_add_relu_add_layernorm.cpp +++ b/profiler/src/profile_gemm_add_relu_add_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_bias_add_reduce.cpp b/profiler/src/profile_gemm_bias_add_reduce.cpp index 6d86db08223a771ddc4373c014a6a2a42f92285b..76daffbc67cb95ac19d52a6791d90816a5e6656b 100644 --- a/profiler/src/profile_gemm_bias_add_reduce.cpp +++ b/profiler/src/profile_gemm_bias_add_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_bilinear.cpp b/profiler/src/profile_gemm_bilinear.cpp index bfeef4adb185f32fc6623a18df3bff5cad6654ca..4527a2fa00229944afe6886c3d239b5f0cab0495 100644 --- a/profiler/src/profile_gemm_bilinear.cpp +++ b/profiler/src/profile_gemm_bilinear.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_fastgelu.cpp b/profiler/src/profile_gemm_fastgelu.cpp index 2a137224cb096f4bf2968f06be13d71d44614794..93573002ef1c367a00d42dc2ea0fe99fca98086a 100644 --- a/profiler/src/profile_gemm_fastgelu.cpp +++ b/profiler/src/profile_gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_multiply_add.cpp b/profiler/src/profile_gemm_multiply_add.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fd1f5c65c1ffdb7fa6c8a8481781964f5a7bd922 --- /dev/null +++ b/profiler/src/profile_gemm_multiply_add.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_multiply_add_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_multiply_add" +#define OP_DESC "GEMM+MULTIPLY+ADD" + +int profile_gemm_multiply_add(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN_MN, // 0 + MK_NK_MN_MN_MN, // 1 + }; + + enum struct MatrixDataType + { + F16_F16_F16_F16_F16, // 0 + F16_F8_F32_F32_F16, // 1 + }; + + if(argc != 16) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp16; 1: fp16Afp8B)\n"); + printf("arg3: matrix layout (0: E[m, n] = Multiply_Add((A[m, k] * B[k, n]) x D1[m, n] + D0[m, n]);\n"); + printf(" 1: E[m, n] = Multiply_Add((A[m, k] * B[n, k]) x D1[m, n] + D0[m, n]);\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideD1 = std::stoi(argv[14]); + const int StrideE = std::stoi(argv[15]); + + using F8 = ck::f8_t; + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d0_type, + auto d1_type, + auto e_type, + auto a_layout, + auto b_layout, + auto d0_layout, + auto d1_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using D0DataType = decltype(d0_type); + using D1DataType = decltype(d1_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using D0Layout = decltype(d0_layout); + using D1Layout = decltype(d1_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideD1 = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_multiply_add_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideD1 < 0) ? DefaultStrideD1 : StrideD1, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::MK_NK_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 && + layout == MatrixLayout::MK_KN_MN_MN_MN) + { + return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 && + layout == MatrixLayout::MK_NK_MN_MN_MN) + { + return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_multiply_add); diff --git a/profiler/src/profile_gemm_reduce.cpp b/profiler/src/profile_gemm_reduce.cpp index 395bf0627e43617d04d30a903a1f3cbdecf4a557..48f6f5eb49da151136600c2d03b99d959d51565f 100644 --- a/profiler/src/profile_gemm_reduce.cpp +++ b/profiler/src/profile_gemm_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_gemm_splitk.cpp b/profiler/src/profile_gemm_splitk.cpp index f636ce718c669feb23a60fb1853a1219fea3da55..617e0b9cd472b196ec7a847db7f7f224a5d4c781 100644 --- a/profiler/src/profile_gemm_splitk.cpp +++ b/profiler/src/profile_gemm_splitk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -23,6 +23,8 @@ enum struct GemmDataType F16_F16_F16, // 1 BF16_BF16_BF16, // 2 INT8_INT8_INT8, // 3 + F8_F16_F16, // 4 + F16_F8_F16, // 5 }; #define OP_NAME "gemm_splitk" @@ -33,7 +35,7 @@ int profile_gemm_splitk(int argc, char* argv[]) if(argc != 15) { printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); - printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); @@ -65,6 +67,7 @@ int profile_gemm_splitk(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; + using F8 = ck::f8_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -143,6 +146,38 @@ int profile_gemm_splitk(int argc, char* argv[]) { return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); } + else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}); + } else { std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_gemm_streamk.cpp b/profiler/src/profile_gemm_streamk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a0a49eb36d241d83e6bf76229c7799eb3fc2562c --- /dev/null +++ b/profiler/src/profile_gemm_streamk.cpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_streamk_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +#define OP_NAME "gemm_streamk" +#define OP_DESC "StreamK GEMM" + +int profile_gemm_streamk(int argc, char* argv[]) +{ + if(argc < 14) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: num_sk_blocks (optional)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const uint32_t NumSKBlocks = + argc >= 15 ? static_cast(std::stoul(std::string(argv[14]))) : 0xffffffff; + + using F32 = float; + using F16 = ck::half_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_streamk_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA <= 0) ? DefaultStrideA : StrideA, + (StrideB <= 0) ? DefaultStrideB : StrideB, + (StrideC <= 0) ? DefaultStrideC : StrideC, + NumSKBlocks); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_streamk); diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..55d199317a8dc83faa8beb7f7bebe332f6d2c53c --- /dev/null +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" +#include "profiler_operation_registry.hpp" + +namespace { + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 +}; + +#define OP_NAME "grouped_conv_bwd_data" +#define OP_DESC "Grouped Convolution Backward Data" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Output fp32, Weight fp32, Input fp32\n" + << " 1: Output fp16, Weight fp16, Input fp16\n" + << " 2: Output bf16, Weight bf16, Input bf16\n" + << "arg3: tensor layout (0: Output[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Input[G, N, Ho, Wo, K]\n" + << " 1: Output[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Input[N, Ho, Wo, G, K])\n" + << "arg4: verification (0: no, 1: yes)\n" + << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +} // namespace + +int profile_grouped_conv_bwd_data(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 9) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + + using namespace ck::tensor_layout::convolution; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto out_layout, + auto wei_layout, + auto in_layout, + auto wei_type, + auto out_type, + auto in_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using OutLayout = decltype(out_layout); + using WeiLayout = decltype(wei_layout); + using InLayout = decltype(in_layout); + + using OutDataType = decltype(out_type); + using WeiDataType = decltype(wei_type); + using InDataType = decltype(in_type); + + bool pass = ck::profiler::profile_grouped_conv_bwd_data_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 2) + { + if(layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{}); + } + } + else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{}); + } + } + } + else if(num_dim_spatial == 3) + { + if(layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, BF16{}, BF16{}, BF16{}); + } + } + else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, BF16{}, BF16{}, BF16{}); + } + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_conv_bwd_data); diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index dfd8a099f5408f76fd51a9c69a988591edf5348f..be8a3230f76ed224a2ecf65cd84a3a700f7df898 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -15,6 +15,7 @@ enum struct ConvLayout { GNCHW_GKCYX_GNKHW, // 0 GNHWC_GKYXC_GNHWK, // 1 + NHWGC_GKYXC_NHWGK, // 2 }; enum struct ConvDataType @@ -37,6 +38,8 @@ static void print_helper_msg() "N, K, Ho, Wo]\n" << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, " "N, Ho, Wo, K]\n" + << " 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, " + "Ho, Wo, G, K]\n" << "arg4: verification (0: no, 1: yes)\n" << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg6: print tensor value (0: no; 1: yes)\n" @@ -80,17 +83,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using F16 = ck::half_t; using BF16 = ck::bhalf_t; - using GNWC = ck::tensor_layout::convolution::GNWC; - using GNHWC = ck::tensor_layout::convolution::GNHWC; - using GNDHWC = ck::tensor_layout::convolution::GNDHWC; - - using GKXC = ck::tensor_layout::convolution::GKXC; - using GKYXC = ck::tensor_layout::convolution::GKYXC; - using GKZYXC = ck::tensor_layout::convolution::GKZYXC; - - using GNWK = ck::tensor_layout::convolution::GNWK; - using GNHWK = ck::tensor_layout::convolution::GNHWK; - using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + using namespace ck::tensor_layout::convolution; constexpr auto I1 = ck::Number<1>{}; constexpr auto I2 = ck::Number<2>{}; @@ -157,6 +150,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}); } } + else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}); + } + } else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) @@ -173,6 +182,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}); } } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}); + } + } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 9ff3c15af05a5f67988ef4842e00e2671289bb1b..d0b424cde64249aff7da0646a6bb11e83ccecf39 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index 34647adabfea27f96d76147e58799847727357ab..5636656ba3972caa436a64681af9b18f84bb6ebc 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -88,7 +88,7 @@ int profile_grouped_gemm(int argc, char* argv[]) const auto StrideBs = argToIntArray(argv[12]); const auto StrideCs = argToIntArray(argv[13]); const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1; - +#ifdef CK_ENABLE_FP16 if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { ck::profiler::profile_grouped_gemm_impl #include diff --git a/profiler/src/profile_groupnorm.cpp b/profiler/src/profile_groupnorm.cpp index d55529a0f033c36de7951b383a62c5ac656317f5..d55784ff0ad7274fce9bbf66b26a508f5994a578 100644 --- a/profiler/src/profile_groupnorm.cpp +++ b/profiler/src/profile_groupnorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_image_to_column.cpp b/profiler/src/profile_image_to_column.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bf4312a6cfde8da8b39d5034e18c841cde159129 --- /dev/null +++ b/profiler/src/profile_image_to_column.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_image_to_column_impl.hpp" +#include "profiler_operation_registry.hpp" + +namespace { + +enum struct ConvLayout +{ + NHWC, // 0 +}; + +enum struct DataType +{ + F32_F32, // 0 + F16_F16, // 1 + BF16_BF16, // 2 + INT8_INT8, // 3 +}; + +#define OP_NAME "image_to_column" +#define OP_DESC "Image To Column" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8)\n" + << "arg3: tensor layout (0: Input[N, Hi, Wi, C], Output[N * Ho * Wo, Y * X * C])\n" + << "arg4: verification (0: no, 1: yes)\n" + << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +} // namespace + +int profile_image_to_column(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 9) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const int num_dim_spatial = std::stoi(argv[8]); + + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + + using namespace ck::tensor_layout::convolution; + + constexpr auto I1 = ck::Number<1>{}; + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, auto in_layout, auto in_type, auto out_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + + using InDataType = decltype(in_type); + using OutDataType = decltype(out_type); + + bool pass = ck::profiler:: + profile_image_to_column_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + // NHWC + if(layout == ConvLayout::NHWC) + { + if(num_dim_spatial == 1) + { + if(data_type == DataType::F32_F32) + { + return profile(I1, GNWC{}, F32{}, F32{}); + } + else if(data_type == DataType::F16_F16) + { + return profile(I1, GNWC{}, F16{}, F16{}); + } + else if(data_type == DataType::BF16_BF16) + { + return profile(I1, GNWC{}, BF16{}, BF16{}); + } + else if(data_type == DataType::INT8_INT8) + { + return profile(I1, GNWC{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 2) + { + if(data_type == DataType::F32_F32) + { + return profile(I2, GNHWC{}, F32{}, F32{}); + } + else if(data_type == DataType::F16_F16) + { + return profile(I2, GNHWC{}, F16{}, F16{}); + } + else if(data_type == DataType::BF16_BF16) + { + return profile(I2, GNHWC{}, BF16{}, BF16{}); + } + else if(data_type == DataType::INT8_INT8) + { + return profile(I2, GNHWC{}, INT8{}, INT8{}); + } + } + else if(num_dim_spatial == 3) + { + if(data_type == DataType::F32_F32) + { + return profile(I3, GNDHWC{}, F32{}, F32{}); + } + else if(data_type == DataType::F16_F16) + { + return profile(I3, GNDHWC{}, F16{}, F16{}); + } + else if(data_type == DataType::BF16_BF16) + { + return profile(I3, GNDHWC{}, BF16{}, BF16{}); + } + else if(data_type == DataType::INT8_INT8) + { + return profile(I3, GNDHWC{}, INT8{}, INT8{}); + } + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_image_to_column); diff --git a/profiler/src/profile_layernorm.cpp b/profiler/src/profile_layernorm.cpp index e93fc2dbd2bf004389eed968e0518a2cdb1ea4f6..7bf210e67837fa8fe0fe8982c40a1793344823a0 100644 --- a/profiler/src/profile_layernorm.cpp +++ b/profiler/src/profile_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_max_pool3d_bwd.cpp b/profiler/src/profile_max_pool3d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45a64df42325611ec2a161586d6717f71d004f65 --- /dev/null +++ b/profiler/src/profile_max_pool3d_bwd.cpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_max_pool3d_bwd_impl.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct maxPoolbwdArgParser +{ + std::unordered_map> long_opts = {{"length", {}}, + {"wsize", {}}, + {"wstride", {}}, + {"wdilation", {}}, + {"pad1", {}}, + {"pad2", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_max_pool3d_bwd() +{ + std::cout << "arg1: data type (0: fp16; 1: fp32; 5: bf16)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: input tensor length for NCDHW(e.g, --length 2 32 30 30 30) \n" + << "--wsize: window size for ZYX (e.g, --wsize 2 2 2) \n" + << "--wstride: window stride for DHW (e.g, --wstride 2 2 2) \n" + << "--wdilation: window dilation for DHW (e.g, --wdilation 1 1 1) \n" + << "--pad1: left side of padding in DHW (e.g, --pad1 1 1 1) \n" + << "--pad2: right side of padding in DHW (e.g, --pad2 1 1 1) \n" + << "eg: ckProfiler max_pool3d_bwd 0 1 2 0 1 --length 2 32 30 30 30 --wsize 2 2 2 " + "--wstride 2 2 2 --wdilation 1 1 1 --pad1 1 1 1 --pad2 1 1 1" + << std::endl; +} + +int profile_max_pool3d_bwd(int argc, char* argv[]) +{ + ck::DataTypeEnum data_type = ck::DataTypeEnum::Half; + bool do_verification = true; + int init_method = 0; + bool do_log = false; + bool time_kernel = true; + + std::vector in_length = {2, 32, 30, 30, 30}; + std::vector wsize = {2, 2, 2}; + std::vector wstride = {2, 2, 2}; + std::vector wdilation = {1, 1, 1}; + std::vector pad1 = {1, 1, 1}; + std::vector pad2 = {1, 1, 1}; + + if(argc != 2 && argc != 33) + { + print_help_max_pool3d_bwd(); + return 0; + } + else if(argc == 33) + { + data_type = static_cast(std::stoi(argv[2])); + do_verification = std::stoi(argv[3]); + init_method = std::stoi(argv[4]); + do_log = std::stoi(argv[5]); + time_kernel = std::stoi(argv[6]); + + // parse the long options + maxPoolbwdArgParser arg_parser; + arg_parser(argc, argv); + in_length = arg_parser.long_opts["length"]; + wsize = arg_parser.long_opts["wsize"]; + wstride = arg_parser.long_opts["wstride"]; + wdilation = arg_parser.long_opts["wdilation"]; + pad1 = arg_parser.long_opts["pad1"]; + pad2 = arg_parser.long_opts["pad2"]; + } + +#ifdef CK_ENABLE_FP16 + using F16 = ck::half_t; +#endif +#ifdef CK_ENABLE_BF16 + using BF16 = ck::bhalf_t; +#endif +#ifdef CK_ENABLE_FP32 + using F32 = float; +#endif + using I32 = int32_t; + + if(false) + ; +#ifdef CK_ENABLE_FP16 + else if(data_type == ck::DataTypeEnum::Half) + { + ck::profiler::profile_max_pool3d_bwd_impl(do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif +#ifdef CK_ENABLE_BF16 + else if(data_type == ck::DataTypeEnum::BFloat16) + { + ck::profiler::profile_max_pool3d_bwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif +#ifdef CK_ENABLE_FP32 + else if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_max_pool3d_bwd_impl(do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("max_pool3d_bwd", "max_pool3d bwd", profile_max_pool3d_bwd); diff --git a/profiler/src/profile_max_pool3d_fwd.cpp b/profiler/src/profile_max_pool3d_fwd.cpp index 90c6e4e2b904d6f51ed8692fc359719364c9a88c..52fdf29fe4e4f48710074565e42c55dfbe756961 100644 --- a/profiler/src/profile_max_pool3d_fwd.cpp +++ b/profiler/src/profile_max_pool3d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -13,8 +13,12 @@ using ck::index_t; struct maxPoolFwdArgParser { - std::unordered_map> long_opts = { - {"length", {}}, {"wsize", {}}, {"wstride", {}}, {"pad1", {}}, {"pad2", {}}}; + std::unordered_map> long_opts = {{"length", {}}, + {"wsize", {}}, + {"wstride", {}}, + {"wdilation", {}}, + {"pad1", {}}, + {"pad2", {}}}; bool parse_opt(int argc, char* argv[], const std::string& key, int i) { @@ -47,7 +51,7 @@ struct maxPoolFwdArgParser void print_help_max_pool3d_fwd() { - std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + std::cout << "arg1: data type (0: fp16; 1: fp32; 5: bf16)\n" << "arg2: verification (0: no; 1: yes)\n" << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" << "arg4: print tensor value (0: no; 1: yes)\n" @@ -56,10 +60,11 @@ void print_help_max_pool3d_fwd() << "--length: input tensor length for NCDHW(e.g, --length 2 32 30 30 30) \n" << "--wsize: window size for ZYX (e.g, --wsize 2 2 2) \n" << "--wstride: window stride for DHW (e.g, --wstride 2 2 2) \n" + << "--wdilation: window dilation for DHW (e.g, --wdilation 1 1 1) \n" << "--pad1: left side of padding in DHW (e.g, --pad1 1 1 1) \n" << "--pad2: right side of padding in DHW (e.g, --pad2 1 1 1) \n" << "eg: ckProfiler max_pool3d_fwd 0 1 2 0 1 0 --length 2 32 30 30 30 --wsize 2 2 2 " - "--wstride 2 2 2 --pad1 1 1 1 --pad2 1 1 1" + "--wstride 2 2 2 --wdilation 1 1 1 --pad1 1 1 1 --pad2 1 1 1" << std::endl; } @@ -75,15 +80,16 @@ int profile_max_pool3d_fwd(int argc, char* argv[]) std::vector in_length = {2, 32, 30, 30, 30}; std::vector wsize = {2, 2, 2}; std::vector wstride = {2, 2, 2}; + std::vector wdilation = {1, 1, 1}; std::vector pad1 = {1, 1, 1}; std::vector pad2 = {1, 1, 1}; - if(argc != 2 && argc != 30) + if(argc != 2 && argc != 34) { print_help_max_pool3d_fwd(); return 0; } - else if(argc == 30) + else if(argc == 34) { data_type = static_cast(std::stoi(argv[2])); do_verification = std::stoi(argv[3]); @@ -98,65 +104,136 @@ int profile_max_pool3d_fwd(int argc, char* argv[]) in_length = arg_parser.long_opts["length"]; wsize = arg_parser.long_opts["wsize"]; wstride = arg_parser.long_opts["wstride"]; + wdilation = arg_parser.long_opts["wdilation"]; pad1 = arg_parser.long_opts["pad1"]; pad2 = arg_parser.long_opts["pad2"]; } - using F16 = ck::half_t; - using F32 = float; - using I32 = int32_t; +#ifdef CK_ENABLE_FP16 + using F16 = ck::half_t; +#endif +#ifdef CK_ENABLE_BF16 + using BF16 = ck::bhalf_t; +#endif +#ifdef CK_ENABLE_FP32 + using F32 = float; +#endif + using I32 = int32_t; + using NDHWC = ck::tensor_layout::convolution::NDHWC; + +#if 1 constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else + constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif - if(data_type == ck::DataTypeEnum::Half) + if(false) + ; +#ifdef CK_ENABLE_FP16 + else if(data_type == ck::DataTypeEnum::Half) + { + if(return_index) + ck::profiler:: + profile_pool3d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + else + ck::profiler:: + profile_pool3d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); + } +#endif +#ifdef CK_ENABLE_BF16 + else if(data_type == ck::DataTypeEnum::BFloat16) { if(return_index) - ck::profiler::profile_pool3d_fwd_impl( - do_verification, - init_method, - do_log, - time_kernel, - in_length, - wsize, - wstride, - pad1, - pad2); + ck::profiler::profile_pool3d_fwd_impl(do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); else - ck::profiler::profile_pool3d_fwd_impl( - do_verification, - init_method, - do_log, - time_kernel, - in_length, - wsize, - wstride, - pad1, - pad2); + ck::profiler::profile_pool3d_fwd_impl(do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); } +#endif +#ifdef CK_ENABLE_FP32 else if(data_type == ck::DataTypeEnum::Float) { if(return_index) - ck::profiler::profile_pool3d_fwd_impl( - do_verification, - init_method, - do_log, - time_kernel, - in_length, - wsize, - wstride, - pad1, - pad2); + ck::profiler:: + profile_pool3d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); else - ck::profiler::profile_pool3d_fwd_impl( - do_verification, - init_method, - do_log, - time_kernel, - in_length, - wsize, - wstride, - pad1, - pad2); + ck::profiler:: + profile_pool3d_fwd_impl( + do_verification, + init_method, + do_log, + time_kernel, + in_length, + wsize, + wstride, + wdilation, + pad1, + pad2); } +#endif else { throw std::runtime_error("not implemented yet"); diff --git a/profiler/src/profile_reduce.cpp b/profiler/src/profile_reduce.cpp index 6925371858ee913dcb5d2b529c3e350a6eec5f4f..e4af5680a5d8a192edfc380d18fbe0a2c595cbbb 100644 --- a/profiler/src/profile_reduce.cpp +++ b/profiler/src/profile_reduce.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profile_softmax.cpp b/profiler/src/profile_softmax.cpp index 78b64dda7d70ec0a6cf8487562a9782c261f58c7..dfe8d95c904e297fcd2f7b83a166fc61c8950233 100644 --- a/profiler/src/profile_softmax.cpp +++ b/profiler/src/profile_softmax.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -92,27 +92,76 @@ int profile_softmax(int argc, char* argv[]) { if(data_type == SoftmaxDataType::F16_F16) { - ck::profiler::profile_softmax_impl(do_verification, - init_method, - do_log, - time_kernel, - length, - stride, - reduce, - double(alpha), - double(beta)); + if(reduce.size() == 1) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 2) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 3) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else + throw std::runtime_error("invalid number of dimensions to reduce"); } else if(data_type == SoftmaxDataType::F32_F32) { - ck::profiler::profile_softmax_impl(do_verification, - init_method, - do_log, - time_kernel, - length, - stride, - reduce, - double(alpha), - double(beta)); + if(reduce.size() == 1) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 2) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 3) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else + throw std::runtime_error("invalid number of dimensions to reduce"); } else { @@ -124,27 +173,97 @@ int profile_softmax(int argc, char* argv[]) { if(data_type == SoftmaxDataType::F16_F16) { - ck::profiler::profile_softmax_impl(do_verification, - init_method, - do_log, - time_kernel, - length, - stride, - reduce, - double(alpha), - double(beta)); + if(reduce.size() == 1) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 2) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 3) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 4) + ck::profiler::profile_softmax_impl( + do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else + throw std::runtime_error("invalid number of dimensions to reduce"); } else if(data_type == SoftmaxDataType::F32_F32) { - ck::profiler::profile_softmax_impl(do_verification, - init_method, - do_log, - time_kernel, - length, - stride, - reduce, - double(alpha), - double(beta)); + if(reduce.size() == 1) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 2) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 3) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else if(reduce.size() == 4) + ck::profiler::profile_softmax_impl(do_verification, + init_method, + do_log, + time_kernel, + length, + stride, + reduce, + double(alpha), + double(beta)); + else + throw std::runtime_error("invalid number of dimensions to reduce"); } else { diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 080117e390c4df1bc14195649382267bcccc18f7..0f528c008f1d17afd42139d572d4f1cf1963fadd 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/profiler/src/profiler_operation_registry.hpp b/profiler/src/profiler_operation_registry.hpp index 91ff291233066527609133dc155e67388119385a..276b7b38dcd28ee935c108819cd0102b2bec4862 100644 --- a/profiler/src/profiler_operation_registry.hpp +++ b/profiler/src/profiler_operation_registry.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/script/check_copyright_year.sh b/script/check_copyright_year.sh new file mode 100755 index 0000000000000000000000000000000000000000..f7709472efaa2f8bba4a5d7807b98a66f8caa807 --- /dev/null +++ b/script/check_copyright_year.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +current_year=$(date +%Y) +exit_code=0 + +for file in $@; do + if grep -q "Copyright (c)" $file + then + if ! grep -q "Copyright (c).*$current_year" $file + then + echo "ERROR: File $file has a copyright notice without the current year ($current_year)." + exit_code=1 + fi + fi +done + +exit $exit_code diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index f9d11fcd8cb1144fc5e1091e848878e92878dbd2..da83254f00fef01532011bb47bd6a899849b3d86 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' -git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' +git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 426f68d4432e19f378c543e1edb90364f97e54af..da9e92b8622fe106746f0ddb4f04def8e6161760 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -16,4 +16,3 @@ cmake -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE} - diff --git a/script/install_precommit.sh b/script/install_precommit.sh new file mode 100755 index 0000000000000000000000000000000000000000..296280bb0303802da1750382d65f17a65c9d07b2 --- /dev/null +++ b/script/install_precommit.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +run_and_check() { + "$@" + status=$? + if [ $status -ne 0 ]; then + echo "Error with \"$@\": Exited with status $status" + exit $status + fi + return $status +} + +echo "I: Installing tools required for pre-commit checks..." +run_and_check apt install clang-format-12 + +echo "I: Installing pre-commit itself..." +run_and_check pip3 install pre-commit +run_and_check pre-commit install + +echo "I: Installation successful." diff --git a/script/profile_batched_gemm.sh b/script/profile_batched_gemm.sh index d19ddd0c6525dd0737187f904719ac4586e0722a..f90baaed685f9bfb83604801368f774c7ec71edb 100755 --- a/script/profile_batched_gemm.sh +++ b/script/profile_batched_gemm.sh @@ -3,13 +3,6 @@ ## GPU visibility export HIP_VISIBLE_DEVICES=0 DRIVER="../build/bin/ckProfiler" -OP=$1 -DATATYPE=$2 -LAYOUT=$3 -VERIFY=$4 -INIT=$5 -LOG=$6 -TIME=$7 OP=$1 DATATYPE=$2 diff --git a/script/uninstall_precommit.sh b/script/uninstall_precommit.sh new file mode 100755 index 0000000000000000000000000000000000000000..b0d4d15166fd85c711b0fb4f1bb87eb9ece720aa --- /dev/null +++ b/script/uninstall_precommit.sh @@ -0,0 +1 @@ +pre-commit uninstall diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index dad9b53cecc49ba7065c35eebb33af2577d9bc09..8fddd6085824d62ec930b66d9350bc94f6cceabd 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -57,7 +57,10 @@ add_subdirectory(data_type) add_subdirectory(elementwise_normalization) add_subdirectory(batchnorm) add_subdirectory(contraction) -add_subdirectory(pool_fwd) -if(GPU_TARGETS MATCHES "gfx1100") +add_subdirectory(pool) +add_subdirectory(batched_gemm_multi_d) +add_subdirectory(grouped_convnd_bwd_data) +add_subdirectory(image_to_column) +if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 5552a6c98e8166f1085afb2a9fafc003205908cd..8a3269e90f85e081e556619c900013cdebe959f3 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -1,17 +1,27 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) - target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) - target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) - - add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) - target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) - target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance) - - add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) - target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) - target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance) - - add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) - target_link_libraries(test_batched_gemm_int8 PRIVATE utility) - target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance) -endif() \ No newline at end of file +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) + target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) + target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) + endif() + if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) + target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) + target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) + target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) + target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance) + endif() + if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) + add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) + target_link_libraries(test_batched_gemm_int8 PRIVATE utility) + target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance) + endif() + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/batched_gemm/batched_gemm_bf16.cpp b/test/batched_gemm/batched_gemm_bf16.cpp index 78be540627850426a075e2bde1516410ef070756..5d12a1e956d469d0a1be7301c993bfd981ebafc3 100644 --- a/test/batched_gemm/batched_gemm_bf16.cpp +++ b/test/batched_gemm/batched_gemm_bf16.cpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + namespace { using ADataType = ck::bhalf_t; using BDataType = ck::bhalf_t; @@ -12,6 +14,8 @@ using CDataType = ck::bhalf_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; } // namespace int main() @@ -23,21 +27,87 @@ int main() bool pass = true; - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + using namespace ck::tensor_operation::device; + + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM bf16: " << (pass ? "Pass" : "Fail") << std::endl; return pass ? 0 : 1; diff --git a/test/batched_gemm/batched_gemm_fp16.cpp b/test/batched_gemm/batched_gemm_fp16.cpp index 6cbbedf6774d2f90f85e5128c5bae8fbe912dc8e..a2b61d951a724bce18836c229ca7ef865336d849 100644 --- a/test/batched_gemm/batched_gemm_fp16.cpp +++ b/test/batched_gemm/batched_gemm_fp16.cpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + namespace { using ADataType = ck::half_t; using BDataType = ck::half_t; @@ -12,6 +14,8 @@ using CDataType = ck::half_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; } // namespace int main() @@ -23,21 +27,87 @@ int main() bool pass = true; - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + using namespace ck::tensor_operation::device; + + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl; return pass ? 0 : 1; diff --git a/test/batched_gemm/batched_gemm_fp32.cpp b/test/batched_gemm/batched_gemm_fp32.cpp index c9e565e264b493d6f7f1b248fe6d307150a4c989..2b18d166e68f96e53670b4783038228b22b41cec 100644 --- a/test/batched_gemm/batched_gemm_fp32.cpp +++ b/test/batched_gemm/batched_gemm_fp32.cpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + namespace { using ADataType = float; using BDataType = float; @@ -12,6 +14,8 @@ using CDataType = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; } // namespace int main() @@ -23,21 +27,87 @@ int main() bool pass = true; - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + using namespace ck::tensor_operation::device; + + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM fp32: " << (pass ? "Pass" : "Fail") << std::endl; return pass ? 0 : 1; diff --git a/test/batched_gemm/batched_gemm_int8.cpp b/test/batched_gemm/batched_gemm_int8.cpp index 4da941a5766bc6decc5adc938b4ce39da5d8a6db..f607eaa84b996a0c28394647021953c5da6017a0 100644 --- a/test/batched_gemm/batched_gemm_int8.cpp +++ b/test/batched_gemm/batched_gemm_int8.cpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp" + namespace { using ADataType = int8_t; using BDataType = int8_t; @@ -12,6 +14,8 @@ using CDataType = int8_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; } // namespace int main() @@ -23,21 +27,87 @@ int main() bool pass = true; - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + using namespace ck::tensor_operation::device; + + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount); - pass = pass && - ck::profiler::profile_batched_gemm_impl( - true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); + pass = pass && ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount); std::cout << "test BatchedGEMM int8: " << (pass ? "Pass" : "Fail") << std::endl; return pass ? 0 : 1; diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index eff3391750287d82e56de0c0f34d929554a03420..404e74f3bc2130abdae653e63ecea61c7389f52f 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -1,7 +1,13 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_custom_target(test_batched_gemm_gemm) - - add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) - target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) - add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) -endif() \ No newline at end of file +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_custom_target(test_batched_gemm_gemm) + add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) + target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) + add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) + set(target 1) + endif() + endif() +endforeach() \ No newline at end of file diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp index aa113de219437e900b2de8739ddaffb2caf3cefa..1a8d5c2e55906fb77cfc3fef7d6c985c37eefefb 100644 --- a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp +++ b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_gemm_util.hpp" diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp index 53c4d37c44781be976f917362d0c433d1ec1711f..b0fffc466efd7de5f9a8055cb3f7e372bfad2ef7 100644 --- a/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp +++ b/test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batched_gemm_multi_d/CMakeLists.txt b/test/batched_gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..825f0dd235193f49b11d58ceb7b4484cc8c0f969 --- /dev/null +++ b/test/batched_gemm_multi_d/CMakeLists.txt @@ -0,0 +1,4 @@ +if(DL_KERNELS) + add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp) + target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance) +endif() diff --git a/test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp b/test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c04086e0e746fdac801e8f85d80f1599b3695fe --- /dev/null +++ b/test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "profiler/profile_batched_gemm_impl.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp" + +namespace { +using F16 = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +class TestBatchedGemmMultiD : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + + static constexpr int M = 512; + static constexpr int N = 256; + static constexpr int K = 128; + static constexpr int BatchCount = 3; + + template + void Run() + { + using namespace ck::tensor_operation::device; + + const bool pass = + ck::profiler::profile_batched_gemm_impl>( + true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + EXPECT_TRUE(pass); + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; +} // namespace + +TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes); +#ifdef __fp16 +TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run(); } +#endif +#ifdef CK_ENABLE_INT8 +TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run(); } +#endif diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index 1a8d7112a344a8a98a668dcdb8ff2afef76caa83..af95a50eabae20b36cfb69add37d18f7f74bcff9 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -1,5 +1,12 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) - target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility) - target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) -endif() +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) + set(target 1) + endif() + endif() +endforeach() diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp index b150ce50d166d5a673c6b34c468b7988f3b4d91f..dd2638ce899bee16accfcfb85c402909b2c2a05c 100644 --- a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp +++ b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batched_gemm_softmax_gemm/CMakeLists.txt b/test/batched_gemm_softmax_gemm/CMakeLists.txt index dee0bb56ae9badd74ea7e888e68b8812dfdeb889..c49175a2eb1b6be9da1aeeb763c551af67e15c08 100644 --- a/test/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,7 +1,13 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_custom_target(test_batched_gemm_softmax_gemm) - - add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) - target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) - add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) -endif() \ No newline at end of file +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_custom_target(test_batched_gemm_softmax_gemm) + add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) + target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) + add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) + set(target 1) + endif() + endif() +endforeach() \ No newline at end of file diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp index 5df7769d5f675e906e350738751923d26b6af39d..cb46a995c61c4dc15c9fcbd9c2ca07b850f90069 100644 --- a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp +++ b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_util.hpp" diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp index 98debe19c3c0b419f46dd08321f43f846536772a..d8ee744c6098843e9e78e21ac59a5f3733f40aa8 100644 --- a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp +++ b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt index cb68af06c679cb058415b0ad62099a46fef6cf69..ae28daa80e4eceed3a85b39d92e7ceccd17caeb1 100644 --- a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,17 +1,26 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_custom_target(test_batched_gemm_softmax_gemm_permute) - - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) - - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) -endif() \ No newline at end of file +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_custom_target(test_batched_gemm_softmax_gemm_permute) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) + add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) + endif() + if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) + add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) + add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) + endif() + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp index fe65a6fb9681a2a336132cf082cd9424e988e1ed..ef88ce6d81845d9ddb61fc00159899f0d11daa5f 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_bias_softmax_gemm_permute_util.hpp" diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp index 7235cd1b0b669cb1cfad6753289a94ad6a569df3..b38b10d1953091bc7d0739a38fa184d3d51a94f1 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp index af5f0efec3853eae68dfbb91f97930a1e8e5727b..d7c39367c8a19290718b315251581f4f3f626add 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp index defe36124056d602f565a20333379fc86a61b6aa..8e0baede113943c0bcb7d6f929217b2be94f0073 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp index 293acd60155ace98eaea89a11eb03fa6710fd877..81d404109fddc4321d870d1c3b8c0a2fe309c9c4 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp index 912bbc91edd51d06c827fc31495d2955bea85a2b..9df03ffd2a11211600872957e9686e105bf8b2f2 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/batchnorm/batchnorm_bwd_rank_4.cpp b/test/batchnorm/batchnorm_bwd_rank_4.cpp index caa7331ea2c3e6dfe57a41272c5f4fba56662e39..a4696cf2a39a774c86d4a68f1a0edf0ce31a3f58 100644 --- a/test/batchnorm/batchnorm_bwd_rank_4.cpp +++ b/test/batchnorm/batchnorm_bwd_rank_4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/batchnorm/batchnorm_fwd_rank_4.cpp b/test/batchnorm/batchnorm_fwd_rank_4.cpp index 13aef7d6bfc9f222e663b951a2f9d44e81fc93bc..9b6fbd0f662648426aad3f9c316604510ff6e1b1 100644 --- a/test/batchnorm/batchnorm_fwd_rank_4.cpp +++ b/test/batchnorm/batchnorm_fwd_rank_4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/batchnorm/batchnorm_infer_rank_4.cpp b/test/batchnorm/batchnorm_infer_rank_4.cpp index 77fc1daae6130e415e88fc6f69966a82280acd3a..ecb4043b361b058ea82b91d98232eadb694ead33 100644 --- a/test/batchnorm/batchnorm_infer_rank_4.cpp +++ b/test/batchnorm/batchnorm_infer_rank_4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/block_swizzle_test/block_swizzle_test.cpp b/test/block_swizzle_test/block_swizzle_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29e118c2adf0a4d5955d083aa62f56c9af3ed0a4 --- /dev/null +++ b/test/block_swizzle_test/block_swizzle_test.cpp @@ -0,0 +1,406 @@ +#include +#include +#include +#include +#include +#include "simple_args.h" + +simple_args_t create_arg(int argc, char** argv) +{ + simple_args_t args; + args.insert("m", "1024", "matrix m") + .insert("n", "1024", "matrix n") + .insert("k", "1024", "matrix k") + .insert("m_per_block", "128", "m_per_block") + .insert("n_per_block", "128", "n_per_block") + .insert("k_per_block", "32", "k_per_block") + .insert("num_cu", "104", "num cu") + .insert("occupancy", "2", "occupancy") + .parse(argc, argv); + return args; +} + +namespace impl { +template +T integer_divide_ceil(T n, T d) +{ + return (n + d - 1) / d; +} + +template +T min(T a, T b) +{ + return a > b ? b : a; +} + +template +T max(T a, T b) +{ + return a > b ? a : b; +} + +} // namespace impl + +struct block_dispatcher_t +{ + public: + uint32_t m_per_block; + uint32_t n_per_block; + uint32_t k_per_block; + uint32_t num_cu; + uint32_t occupancy; + uint32_t m; + uint32_t n; + uint32_t k; + + //-------------------------------------- + + uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t sk_total_iters; + + // uint32_t sk_num_blocks_per_tile; // how many + + uint32_t dp_start_block_idx; + uint32_t dp_iters_per_block; + uint32_t dp_num_blocks; + + uint32_t k_iters_per_tile; + uint32_t k_iters_per_big_block; + //-------------------------------------- + + static constexpr uint32_t min_k_iters_per_sk_block = 1; + + void dump() + { + printf("%dx%dx%d(%dx%dx%d), cu:%d, occ:%d, grids:%d, sk_num_big_blocks:%d, " + "sk_num_blocks:%d, sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, " + "dp_num_blocks:%d, k_iters_per_tile:%d, k_iters_per_big_block:%d\n", + m, + n, + k, + m_per_block, + n_per_block, + k_per_block, + num_cu, + occupancy, + get_grid_dims_x(), + sk_num_big_blocks, + sk_num_blocks, + sk_total_iters, + dp_start_block_idx, + dp_iters_per_block, + dp_num_blocks, + k_iters_per_tile, + k_iters_per_big_block); + } + + block_dispatcher_t(uint32_t m_per_block_, + uint32_t n_per_block_, + uint32_t k_per_block_, + uint32_t num_cu_, + uint32_t occupancy_, + uint32_t m_, + uint32_t n_, + uint32_t k_) + : m_per_block(m_per_block_), + n_per_block(n_per_block_), + k_per_block(k_per_block_), + num_cu(num_cu_), + occupancy(occupancy_), + m(m_), + n(n_), + k(k_) + { + init(); + } + + uint32_t get_grid_dims_x() { return dp_start_block_idx + dp_num_blocks; } + + uint32_t get_block_idx(uint32_t bid) + { + // block id is linearily allocated along sk blocks (dp blocks are fine) + // this function will compute blockIdx.x and the linear sk block mapping + // uint32_t block_idx = 0; + // if(bid < sk_num_big_blocks) { + // uint32_t current_k_iter = bid * k_iters_per_big_block; + // tile_idx = current_k_iter / k_iters_per_tile; + // } + return bid; + } + + uint32_t get_current_itr(uint32_t block_idx) + { + uint32_t current_itr = 0; + if(block_idx < sk_num_big_blocks) + { + current_itr = block_idx * k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + current_itr = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + current_itr = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + } + return current_itr; + } + + void get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + private: + void init() + { + uint32_t num_tiles = + impl::integer_divide_ceil(m, m_per_block) * impl::integer_divide_ceil(n, n_per_block); + k_iters_per_tile = impl::integer_divide_ceil(k, k_per_block); + + // one cu can hold one wg at one time, from the whole chip's point of view + // if number of wg is same as num_cu, we call it 1 dispatch + // if number of wg is 2x num_cu, we call it 2 dispatches. + // one dispatch can deliever wg same as num_cu (full dispatch), or less than num_cu (partial + // dispatch) + // + uint32_t full_dispatches = num_tiles / num_cu; + uint32_t full_dispatch_tiles = full_dispatches * num_cu; + uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles; + + uint32_t sk_occupancy = occupancy; + uint32_t dp_tiles = full_dispatch_tiles; + uint32_t sk_tiles = partial_dispatche_tiles; + + if(full_dispatches < occupancy) + { + // in this case, we allocate all blocks as sk blocks + // sk_occupancy = occupancy - full_dispatches; + sk_occupancy = 1; // TODO: single occ seems better + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatche_tiles; + } + else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1)) + { + // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ... + // occupancy = 3, full_dispatches = 5, 8, 11 ... + // occupancy = 4, full_dispatches = 7, 11 ... + sk_occupancy = 1; // left 1 slot for sk occupancy + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatche_tiles; + } + else + { + // others, we reduce 1 dispatch from dp, together with partial dispatch, + // to construct sk dispatch + sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy); + dp_tiles = full_dispatch_tiles - num_cu; + sk_tiles = partial_dispatche_tiles + num_cu; + } + + // dp_num_blocks = dp_tiles; + // dp_start_block_idx = num_cu * sk_occupancy; + dp_iters_per_block = k_iters_per_tile; + + sk_total_iters = k_iters_per_tile * sk_tiles; + + // printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d, + // partial_dispatche_tiles:%d\n", + // num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles); + + { + uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1); + uint32_t max_sk_tiles = + (sk_tiles >= num_cu) ? num_cu * sk_occupancy + : impl::min(num_cu, sk_total_iters / min_k_iters_per_sk_block); + + // if use dp for sk-block, how many iters do we need + uint32_t dp_for_sk_iters = k_iters_per_tile; + + uint32_t best_sk_score = + std::numeric_limits::max(); // we need to find the smallest sk iters + for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; + tentative_sk_blocks++) + { + uint32_t tentative_sk_iters_per_block = + (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks; + uint32_t tentative_sk_iters = tentative_sk_iters_per_block; + uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles; + + // TODO: carefully adjust this parameter + // the more sk_blocks_per_tile, the worse the overhead + uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile; + if(tentative_sk_blocks % sk_tiles != 0) + { + // penalty for uneven divide + cross_sk_blocks_overhead += + sk_blocks_per_tile * tentative_sk_iters_per_block / 50; + } + + uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead; + + if(tentative_sk_score < best_sk_score) + { + best_sk_score = tentative_sk_score; + sk_num_blocks = tentative_sk_blocks; + } + } + + if(best_sk_score >= dp_for_sk_iters) + { + sk_num_blocks = 0; + } + + if(sk_num_blocks == 0) + { + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + else + { + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + + dp_num_blocks = dp_tiles; + dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu; + } + } + } +}; + +struct tile_work_t +{ + uint32_t tile_idx; + uint32_t iter_begin; + uint32_t k_begin; + uint32_t k_end; + uint32_t k_iters_remaining; +}; + +int main(int argc, char** argv) +{ + simple_args_t arg = create_arg(argc, argv); + block_dispatcher_t block_dispatcher{arg.get_uint32("m_per_block"), + arg.get_uint32("n_per_block"), + arg.get_uint32("k_per_block"), + arg.get_uint32("num_cu"), + arg.get_uint32("occupancy"), + arg.get_uint32("m"), + arg.get_uint32("n"), + arg.get_uint32("k")}; + block_dispatcher.dump(); + // simulate actual kernel launch + uint32_t dim_x = block_dispatcher.get_grid_dims_x(); + uint32_t total_k_iters = + impl::integer_divide_ceil(arg.get_uint32("k"), arg.get_uint32("k_per_block")); + uint32_t num_tiles = + impl::integer_divide_ceil(arg.get_uint32("m"), arg.get_uint32("m_per_block")) * + impl::integer_divide_ceil(arg.get_uint32("n"), arg.get_uint32("n_per_block")); + + std::vector valid_tile_record(num_tiles * total_k_iters); + + for(uint32_t bid = 0; bid < dim_x; bid++) + { + uint32_t block_idx = block_dispatcher.get_block_idx(bid); + bool is_sk_block = block_idx < (block_dispatcher.sk_num_blocks); + bool is_dp_block = block_idx >= block_dispatcher.dp_start_block_idx; + uint32_t iter_start, iter_end; + block_dispatcher.get_block_itr(block_idx, iter_start, iter_end); + uint32_t total_iter_length = iter_end - iter_start; + + while(true) + { + uint32_t iter_length_mod = iter_end % block_dispatcher.k_iters_per_tile; + uint32_t current_iter_length = + impl::min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, + total_iter_length); + uint32_t tile_idx = (iter_end - 1) / block_dispatcher.k_iters_per_tile; + uint32_t tile_iter_start = + ((iter_end - 1) % block_dispatcher.k_iters_per_tile) - current_iter_length + 1; + + if(is_sk_block) + { + printf("[sk_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), " + "iter_end:%d (len:%d)\n", + bid, + block_idx, + tile_idx, + iter_end - current_iter_length, + tile_iter_start, + iter_start, + iter_end, + current_iter_length); + } + else if(is_dp_block) + { + printf("[dp_block] bid:%3d, block_idx:%3d, tile_idx:%3d, iter_start:%d(%d | %d), " + "iter_end:%d (len:%d)\n", + bid, + block_idx, + tile_idx, + iter_end - current_iter_length, + tile_iter_start, + iter_start, + iter_end, + current_iter_length); + } + else + { + printf("[other ] bid:%3d, block_idx:%3d\n", bid, block_idx); + } + + // some validation check + for(auto i = iter_end - current_iter_length; i < iter_end; i++) + { + if(i >= valid_tile_record.size()) + { + printf("unexpected, current iter:%d larger than max:%d\n", + i, + valid_tile_record.size()); + return -1; + } + valid_tile_record[i] = 1; + } + + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + } + } + + int untouched = 0; + for(auto i = 0; i < valid_tile_record.size(); i++) + { + if(valid_tile_record[i] != 1) + { + printf("untouched at %d (%d)\n", i, valid_tile_record.size()); + untouched++; + } + } + printf("untouched %d/%d, %s\n", + untouched, + valid_tile_record.size(), + untouched == 0 ? "valid" : "fail"); +} diff --git a/test/block_swizzle_test/rebuild.sh b/test/block_swizzle_test/rebuild.sh new file mode 100644 index 0000000000000000000000000000000000000000..b07eb5504886b4078988963d3e83533fdc327214 --- /dev/null +++ b/test/block_swizzle_test/rebuild.sh @@ -0,0 +1,3 @@ +CC=g++ + +$CC -Wall -std=c++17 -Iinclude -O3 block_swizzle_test.cpp -o block_swizzle_test.exe \ No newline at end of file diff --git a/test/block_swizzle_test/simple_args.h b/test/block_swizzle_test/simple_args.h new file mode 100644 index 0000000000000000000000000000000000000000..7b10456ce9a33263e5e79f2fec76282f09bf3478 --- /dev/null +++ b/test/block_swizzle_test/simple_args.h @@ -0,0 +1,159 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +struct arg_content_t +{ + std::string name; // key + std::string value; + std::string help_text; +}; + +class simple_args_t +{ + public: + simple_args_t() {} + simple_args_t& insert(const std::string& name_, + const std::string& default_value_, + const std::string& help_text_) + { + arg_content_t arg{name_, default_value_, help_text_}; + + if(arg_map.count(arg.name) != 0) + { + std::cout << "arg:" << arg.name << "already exist" << std::endl; + } + else + { + arg_map[arg.name] = arg; + } + return *this; + } + void usage() + { + for(auto& content : arg_map) + { + std::vector help_text_lines; + size_t pos = 0; + for(size_t next_pos = content.second.help_text.find('\n', pos); + next_pos != std::string::npos;) + { + help_text_lines.push_back( + std::string(content.second.help_text.begin() + pos, + content.second.help_text.begin() + next_pos++)); + pos = next_pos; + next_pos = content.second.help_text.find('\n', pos); + } + help_text_lines.push_back(std::string(content.second.help_text.begin() + pos, + content.second.help_text.end())); + + int arg_name_width = 16 - content.second.name.length(); + arg_name_width = arg_name_width > 0 ? arg_name_width : 2; + std::cout << std::setw(4) << "-" << content.second.name << std::setw(arg_name_width) + << " " << help_text_lines[0] << std::endl; + + for(auto help_next_line = std::next(help_text_lines.begin()); + help_next_line != help_text_lines.end(); + ++help_next_line) + { + std::cout << std::setw(28) << " " << *help_next_line << std::endl; + } + } + } + bool parse(int argc, char* argv[], int start_index = 1) + { + if(argc <= start_index) + { + // std::cout << "not enough args (" << argc << ") with starting index " << start_index + // << std::endl; + return true; + } + for(int i = start_index; i < argc; i++) + { + std::string cur_arg = std::string(argv[i]); + if(cur_arg[0] != '-') + { + std::cout << "illegal input" << std::endl; + usage(); + return false; + } + else if(cur_arg[0] == '-' && cur_arg[1] == '?') + { + usage(); + return false; + } + else + { + size_t found_equal = cur_arg.find('='); + if(found_equal == std::string::npos || found_equal == (cur_arg.length() - 1)) + { + std::cout << "failed while parsing \"" << cur_arg << "\", " + << "arg must be in the form \"-name=value\"" << std::endl; + return false; + } + std::string arg_name = cur_arg.substr(1, found_equal - 1); + std::string arg_value = cur_arg.substr(found_equal + 1); + if(arg_map.count(arg_name) == 0) + { + std::cout << "no such arg \"" << arg_name << "\" registered" << std::endl; + return false; + } + arg_map[arg_name].value = arg_value; + } + } + return true; + } + + std::string get(const std::string& name) const { return get_str(name); } + + std::string get_str(const std::string& name) const + { + assert(arg_map.count(name) != 0); + std::string value = arg_map.at(name).value; + return value; + } + + int get_int(const std::string& name) const + { + assert(arg_map.count(name) != 0); + int value = atoi(arg_map.at(name).value.c_str()); + return value; + } + + uint32_t get_uint32(const std::string& name) const + { + assert(arg_map.count(name) != 0); + uint32_t value = strtoul(arg_map.at(name).value.c_str(), nullptr, 10); + return value; + } + + uint64_t get_uint64(const std::string& name) const + { + assert(arg_map.count(name) != 0); + uint64_t value = strtoull(arg_map.at(name).value.c_str(), nullptr, 10); + return value; + } + + double get_double(const std::string& name) const + { + assert(arg_map.count(name) != 0); + double value = atof(arg_map.at(name).value.c_str()); + return value; + } + + float get_float(const std::string& name) const + { + assert(arg_map.count(name) != 0); + float value = atof(arg_map.at(name).value.c_str()); + return value; + } + + private: + std::unordered_map arg_map; +}; diff --git a/test/block_to_ctile_map/test_block_to_ctile_map.cpp b/test/block_to_ctile_map/test_block_to_ctile_map.cpp index 55d9b59f489203cf68b5aebc1c5d9642801435c0..b8e349eda16f53d97b9655840f9803ef8e6a7ace 100644 --- a/test/block_to_ctile_map/test_block_to_ctile_map.cpp +++ b/test/block_to_ctile_map/test_block_to_ctile_map.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/contraction/CMakeLists.txt b/test/contraction/CMakeLists.txt index ec44151f5325948af6ae458856018ddf23923d60..1f6e0ed34187f1cd66d7bc8a2e8f1d5a16982d42 100644 --- a/test/contraction/CMakeLists.txt +++ b/test/contraction/CMakeLists.txt @@ -1,6 +1,11 @@ add_gtest_executable(test_contraction test_contraction.cpp) target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) -endif() + set(target 1) + endif() +endforeach() diff --git a/test/contraction/test_contraction_interface.cpp b/test/contraction/test_contraction_interface.cpp index c9e720c597b1ca59e8924b250b3c021f51a17a80..12a307f5da9f5a5b9b12cc385901f4cfcc0fc418 100644 --- a/test/contraction/test_contraction_interface.cpp +++ b/test/contraction/test_contraction_interface.cpp @@ -38,7 +38,7 @@ class ContractionInstanceWrapper //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; // clang-format on bool isSupported(std::vector& ADims, diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 73797a7169e78cfeabac101ac0c7da366187edbf..6922bbbcc73dca4d0bc9e98ae49e8b504adb860a 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt index 4bfd21945d308a2159050a4528b7d9a85da26b57..f734b46f53ef886e096b77d529f15644b3d7b96d 100644 --- a/test/convnd_bwd_data/CMakeLists.txt +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -1,4 +1,9 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp) target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) -endif() \ No newline at end of file + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data.cpp index 70231d42ae5f9eb9a368c643b2cf4041ae5fac4e..9d2b6cf5770144e25b21874e6f4da93c523e76bb 100644 --- a/test/convnd_bwd_data/convnd_bwd_data.cpp +++ b/test/convnd_bwd_data/convnd_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 058f9902667dba6a6ab3a750603a4a1e82c56d27..745aceffc932389c8a8b2b35f384dcafa6f5691f 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -1,4 +1,9 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_gtest_executable(test_convnd_fwd convnd_fwd.cpp) target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) -endif() + set(target 1) + endif() +endforeach() diff --git a/test/convnd_fwd/convnd_fwd.cpp b/test/convnd_fwd/convnd_fwd.cpp index a1921a9bfbe0f1221ddada95c75835ab9c942292..fe8798ceb8b1086d7bea43008194845a47f6db43 100644 --- a/test/convnd_fwd/convnd_fwd.cpp +++ b/test/convnd_fwd/convnd_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 088fbfec719f9e948686637fe9bec280f9926bd2..2b63727f19153dcc493bbb7ce4545b136a68b7db 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -2,3 +2,6 @@ if (USE_BITINT_EXTENSION_INT4) add_gtest_executable(test_int4 int4.cpp) target_link_libraries(test_int4 PRIVATE utility) endif() + +add_gtest_executable(test_fp8 fp8.cpp) +target_link_libraries(test_fp8 PRIVATE utility) diff --git a/test/data_type/fp8.cpp b/test/data_type/fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5004fe952775a74cc6b9736d1e34a8675d30187f --- /dev/null +++ b/test/data_type/fp8.cpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" + +using ck::f8_convert_sr; +using ck::f8_t; +using ck::half_t; +using ck::type_convert; + +TEST(FP8, NumericLimits) +{ + EXPECT_EQ(ck::NumericLimits::Min(), 0x08); + EXPECT_EQ(ck::NumericLimits::Max(), 0x77); + EXPECT_EQ(ck::NumericLimits::Lowest(), 0xF7); + EXPECT_EQ(ck::NumericLimits::QuietNaN(), 0x80); +} + +TEST(FP8, ConvertFP32Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to fp8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(type_convert(0.0f)), abs_tol); + // convert minimal float to fp8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(type_convert(std::numeric_limits::min())), + abs_tol); + // convert maximal f8_t to float and check if equal to 240.0 + ASSERT_NEAR(240.0f, type_convert(type_convert(240.0f)), abs_tol); + // convert maximal float to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(240.0f, + type_convert(type_convert(std::numeric_limits::max())), + abs_tol); + // convert inf float to f8_t and check if it is qNan + ASSERT_NEAR(0x80, type_convert(std::numeric_limits::infinity()), abs_tol); + // positive float value to fp8 and back, check if holds + float pos_float = 0.0078125f; + ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); + // negative float value to fp8 and back, check if holds + float neg_float = -0.0156250f; + ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); +} + +TEST(FP8, ConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to fp8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f8_convert_sr(0.0f)), abs_tol); + // convert minimal float to fp8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(f8_convert_sr(std::numeric_limits::min())), + abs_tol); + // convert maximal f8_t to float and check if equal to 240.0 + ASSERT_NEAR(240.0f, type_convert(f8_convert_sr(240.0f)), abs_tol); + // convert maximal float to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(240.0f, + type_convert(f8_convert_sr(std::numeric_limits::max())), + abs_tol); + // convert inf float to f8_t and check if it is qNan + ASSERT_NEAR(0x80, f8_convert_sr(std::numeric_limits::infinity()), abs_tol); + // positive float value to fp8 and back, check if holds + float pos_float = 0.0078125f; + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + // negative float value to fp8 and back, check if holds + float neg_float = -0.0156250f; + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); +} + +TEST(FP8, ConvertFP16Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to fp8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(type_convert(half_t{0.0})), abs_tol); + // convert minimal fp16 to fp8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(type_convert(ck::NumericLimits::Min())), + abs_tol); + // convert maximal f8_t to fp16 and check if equal to 240.0 + ASSERT_NEAR(half_t{240.0}, type_convert(type_convert(half_t{240.0})), abs_tol); + // convert maximal fp16 to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(half_t{240.0}, + type_convert(type_convert(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to f8_t and check if it is QuietNaN + ASSERT_NEAR(0x80, type_convert(ck::NumericLimits::QuietNaN()), abs_tol); + // positive fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); + // negative fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.0156250}; + ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); +} + +TEST(FP8, ConvertFP16Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to fp8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(f8_convert_sr(half_t{0.0})), abs_tol); + // convert minimal fp16 to fp8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(f8_convert_sr(ck::NumericLimits::Min())), + abs_tol); + // convert maximal f8_t to fp16 and check if equal to 240.0 + ASSERT_NEAR(half_t{240.0}, type_convert(f8_convert_sr(half_t{240.0})), abs_tol); + // convert maximal fp16 to fp8 and back, check if clipped to 240.0 + ASSERT_NEAR(half_t{240.0}, + type_convert(f8_convert_sr(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to f8_t and check if it is QuietNaN + ASSERT_NEAR(0x80, f8_convert_sr(ck::NumericLimits::QuietNaN()), abs_tol); + // positive fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + // negative fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.0156250}; + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); +} diff --git a/test/data_type/int4.cpp b/test/data_type/int4.cpp index 252a450bf96896e677cbc5f402ccfe722192d9bc..07549c1c48bb807c9e50520e173d528bf27b3bc9 100644 --- a/test/data_type/int4.cpp +++ b/test/data_type/int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/elementwise_normalization/CMakeLists.txt b/test/elementwise_normalization/CMakeLists.txt index a20eb263256af119453c81f399ea3605bd8b71cc..74a3e4999ee7064c98b503baaccf84465a43752c 100644 --- a/test/elementwise_normalization/CMakeLists.txt +++ b/test/elementwise_normalization/CMakeLists.txt @@ -1,7 +1,6 @@ -add_custom_target(test_elementwise_normalization) - -add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp) - -target_link_libraries(test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance) - -add_dependencies(test_elementwise_normalization test_elementwise_layernorm_fp16) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_custom_target(test_elementwise_normalization) + add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp) + target_link_libraries(test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance) + add_dependencies(test_elementwise_normalization test_elementwise_layernorm_fp16) +endif() \ No newline at end of file diff --git a/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp b/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp index e80995c4f08622432f4e6c3d3c5cf3200fa8c460..d5ce77dc2b956b02439a42eaa341285217780ac8 100644 --- a/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp +++ b/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_elementwise_layernorm_impl.hpp" diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt index 186b06d9117072dd7ecb7f74fc58461795b412c5..7a8836bfefed6409032a5e27c606bd5f1ac5af94 100644 --- a/test/gemm/CMakeLists.txt +++ b/test/gemm/CMakeLists.txt @@ -1,19 +1,12 @@ +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_test_executable(test_gemm_fp32 gemm_fp32.cpp) target_link_libraries(test_gemm_fp32 PRIVATE utility) target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) - +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_test_executable(test_gemm_fp16 gemm_fp16.cpp) target_link_libraries(test_gemm_fp16 PRIVATE utility) target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) - -add_test_executable(test_gemm_bf16 gemm_bf16.cpp) -target_link_libraries(test_gemm_bf16 PRIVATE utility) -target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) - -add_test_executable(test_gemm_int8 gemm_int8.cpp) -target_link_libraries(test_gemm_int8 PRIVATE utility) -target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) - add_library(gemm_standalone_xdl_fp16_instances STATIC instance/gemm_f16_nn_instance.cpp instance/gemm_f16_nt_instance.cpp @@ -24,3 +17,14 @@ add_library(gemm_standalone_xdl_fp16_instances STATIC add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp) target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility) target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/) +endif() +if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) +add_test_executable(test_gemm_bf16 gemm_bf16.cpp) +target_link_libraries(test_gemm_bf16 PRIVATE utility) +target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) +endif() +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) +add_test_executable(test_gemm_int8 gemm_int8.cpp) +target_link_libraries(test_gemm_int8 PRIVATE utility) +target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) +endif() \ No newline at end of file diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp index 5290d466323277daa4e61af1a8273b816a790120..cde5c45aeab59453488ea89e3942550531b65b87 100644 --- a/test/gemm/gemm_bf16.cpp +++ b/test/gemm/gemm_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_fp16.cpp b/test/gemm/gemm_fp16.cpp index 92e225def29623c7dcfc8590a5f805052dedb54d..cad250c6fb4d9eb935653d79999db3ff77a6932e 100644 --- a/test/gemm/gemm_fp16.cpp +++ b/test/gemm/gemm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp index 5d8c4881b621504fe621d056cbcf68bcb04622fc..c35aa77ea7dad4a7bf56e616d3456be932f44d8e 100644 --- a/test/gemm/gemm_fp32.cpp +++ b/test/gemm/gemm_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_fp64.cpp b/test/gemm/gemm_fp64.cpp index 85d7f95bf4ad9ea8df19fdf2600c191f7eabcbf2..e67c8ba4f32331619fa5979c0e66226253993af1 100644 --- a/test/gemm/gemm_fp64.cpp +++ b/test/gemm/gemm_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index e73b22ce9c847f704fe55ad31cc83661b1ebab87..6ece05e306e66cb1013a846f3bdfda6ba4a195f4 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/gemm_standalone_xdl_fp16.cpp b/test/gemm/gemm_standalone_xdl_fp16.cpp index 32a243e0f694c0fa1e57178642d08885dd69d6a3..201a49dcd390090c056ec911f70fee9bd61b61af 100644 --- a/test/gemm/gemm_standalone_xdl_fp16.cpp +++ b/test/gemm/gemm_standalone_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_util.hpp" diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index 9057c0af896891107683605575ee17aeb460c6cf..6c46f4ee89539b1e687b79f40803cfafa62528fc 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/test/gemm/instance/gemm_f16_nn_instance.cpp b/test/gemm/instance/gemm_f16_nn_instance.cpp index 4d65c5876cd92a04cf04a99d3f3655e50c586a8d..9016257f131ffa4a130c009256e66b1d5abe709a 100644 --- a/test/gemm/instance/gemm_f16_nn_instance.cpp +++ b/test/gemm/instance/gemm_f16_nn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_nn_instance.hpp b/test/gemm/instance/gemm_f16_nn_instance.hpp index 5ae3928dc972f53eb1bb963cd2130e110ff19ae0..e174b99a1d8156540bdc3a0fcd16bb9870e96204 100644 --- a/test/gemm/instance/gemm_f16_nn_instance.hpp +++ b/test/gemm/instance/gemm_f16_nn_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_nt_instance.cpp b/test/gemm/instance/gemm_f16_nt_instance.cpp index 431ff1e62e79fc126deac89194996f2db3e2cb01..27103b88d4c25cd8658e8e58f2b0616377ccb99c 100644 --- a/test/gemm/instance/gemm_f16_nt_instance.cpp +++ b/test/gemm/instance/gemm_f16_nt_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_nt_instance.hpp b/test/gemm/instance/gemm_f16_nt_instance.hpp index 99f9ffba4562cd03baff678302e0900a1b842e32..c624425e69573e8293bb9b667f35cfee070c2806 100644 --- a/test/gemm/instance/gemm_f16_nt_instance.hpp +++ b/test/gemm/instance/gemm_f16_nt_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_tn_instance.cpp b/test/gemm/instance/gemm_f16_tn_instance.cpp index 6f5dbc311ebb81501eedc218b43003c777424be2..5b11f4dad9000672b1c2cac01cdb5f1eb4a3ad34 100644 --- a/test/gemm/instance/gemm_f16_tn_instance.cpp +++ b/test/gemm/instance/gemm_f16_tn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_tn_instance.hpp b/test/gemm/instance/gemm_f16_tn_instance.hpp index 62388aeb398dff03c1f997c5ece4d53560492abd..563e10600adde840ac967f5e980362d53e6460e9 100644 --- a/test/gemm/instance/gemm_f16_tn_instance.hpp +++ b/test/gemm/instance/gemm_f16_tn_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_tt_instance.cpp b/test/gemm/instance/gemm_f16_tt_instance.cpp index b6ef5b1cd21a5314c83bedfcdef2d9c3d00906b9..9032150f0c5b7449244c23beea13df71d1e91ea5 100644 --- a/test/gemm/instance/gemm_f16_tt_instance.cpp +++ b/test/gemm/instance/gemm_f16_tt_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_f16_tt_instance.hpp b/test/gemm/instance/gemm_f16_tt_instance.hpp index 9d75b4e48cbba7905c3eebe2b30e8c663e31cbeb..62914d7ac22cb9a7ece1c321ef1a72368788ee2f 100644 --- a/test/gemm/instance/gemm_f16_tt_instance.hpp +++ b/test/gemm/instance/gemm_f16_tt_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm/instance/gemm_wavelet_f16_tn_instance.cpp b/test/gemm/instance/gemm_wavelet_f16_tn_instance.cpp index 51c014a91a4e35ce16bd59feb16683899e6b4909..23ff89da6005b9c612669c3932a763ef65433f19 100644 --- a/test/gemm/instance/gemm_wavelet_f16_tn_instance.cpp +++ b/test/gemm/instance/gemm_wavelet_f16_tn_instance.cpp @@ -1,11 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/test/gemm/instance/gemm_wavelet_f16_tn_instance.hpp b/test/gemm/instance/gemm_wavelet_f16_tn_instance.hpp index 110fc5f7d9a8e78507085c7bf870d8d733417076..ef269d78ee5e11c8813b2f4787583c34530d7b2d 100644 --- a/test/gemm/instance/gemm_wavelet_f16_tn_instance.hpp +++ b/test/gemm/instance/gemm_wavelet_f16_tn_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/gemm/run_gemm_test.inc b/test/gemm/run_gemm_test.inc index ec27729b3c9f02148cdd5537fe7f12e7421a7a4d..d208bb5a7b95a1775815dfa95f55befa5c401ea9 100644 --- a/test/gemm/run_gemm_test.inc +++ b/test/gemm/run_gemm_test.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run_gemm_test() { diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt index b2a5178ac4817458589f254b60b66b6ce6d1a09d..ba0a99b67f3be3ebc84eec5a4616f22d00bcab94 100644 --- a/test/gemm_layernorm/CMakeLists.txt +++ b/test/gemm_layernorm/CMakeLists.txt @@ -1,6 +1,13 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_custom_target(test_gemm_layernorm) add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) -endif() + set(target 1) + endif() + endif() +endforeach() diff --git a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp index 740c63aa7ee5f7b6188872b11c14ab1e3bea9280..3f059968784cd908a2969221d52ff89a2db1400f 100644 --- a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp +++ b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_gemm_add_relu_add_layernorm_impl.hpp" diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index 349f892c19b8c599b2d6bbbb8ec7fd90652ce020..43c8d60745fe48123d979fe4b4a13c1a1c858632 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,3 +1,5 @@ -add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) -target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility) -target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) + target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility) + target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance) +endif() \ No newline at end of file diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp index 029165ece125abaa99dc1605bfbc2467b02ecc26..35a149f52c9f82634b4f63f837cbb6bbc808870e 100644 --- a/test/gemm_reduce/gemm_reduce_fp16.cpp +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index 2274854f88f5143bf42ac00b6801588726c7e8d5..caf30fca595122a4d79935ea3bdbc10d83f364d1 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ b/test/gemm_split_k/CMakeLists.txt @@ -1,4 +1,9 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp) target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance) -endif() + set(target 1) + endif() +endforeach() diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..436ab2af76b8716a90ed350e307705a3227f17bc --- /dev/null +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") + add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) + target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) + add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() \ No newline at end of file diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bdad37fe14741471130f1b7cdbcb1212bb170eda --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include + +#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" + +template +class TestGroupedConvndBwdData : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using OutLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData +{ +}; + +template +class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdData2d, Test2D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndBwdData3d, Test3D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bc592ba665604f93bdcd3b545152376d14d55c88 --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#include + +using DataType = ck::half_t; +using AccDataType = float; +using Pass = ck::tensor_operation::element_wise::PassThrough; + +template +using S = ck::Sequence; +using ConvBackwardDataSpecialization = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + +static constexpr auto ConvBwdDataDefault = ConvBackwardDataSpecialization::Default; +static constexpr auto Filter1x1Stride1Pad0 = ConvBackwardDataSpecialization::Filter1x1Stride1Pad0; + +template +class TestGroupedConvndBwdData : public ::testing::Test +{ + protected: + static constexpr ck::index_t NDimSpatial = 2; + + using OutLayout = std::tuple_element_t<0, Tuple>; + using WeiLayout = std::tuple_element_t<1, Tuple>; + using InLayout = std::tuple_element_t<2, Tuple>; + + // clang-format off + using GroupedConvBwdDataDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 + // ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; + // clang-format on + + ck::utils::conv::ConvParam conv_param; + + template + bool Run() + { + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + std::array out_lengths{}; + std::array out_strides{}; + std::array wei_lengths{}; + std::array wei_strides{}; + std::array in_lengths{}; + std::array in_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), out_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), out_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides); + copy(in_g_n_c_wis_desc.GetLengths(), in_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), in_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto conv = GroupedConvBwdDataDeviceInstance{}; + + auto argument = conv.MakeArgument(nullptr, + nullptr, + std::array{}, + nullptr, + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + Pass{}, + Pass{}, + Pass{}); + return conv.IsSupportedArgument(argument); + } +}; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using NHWGC = ck::tensor_layout::convolution::NHWGC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using NHWGK = ck::tensor_layout::convolution::NHWGK; + +using KernelTypes = + ::testing::Types, std::tuple>; + +template +class TestGroupedConvndBwdDataDefault : public TestGroupedConvndBwdData +{ +}; + +template +class TestGroupedConvndBwdDataFilter1x1 + : public TestGroupedConvndBwdData +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefault, KernelTypes); +TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1, KernelTypes); + +TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck) +{ + // Check filter 3,3 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check strides 2,2 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check with pad + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Supported version + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_TRUE(is_supported); +} + +TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck) +{ + // vector load for A + this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + // vector load for B, E, Ds + this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index da554f677aae1e6407ee7fc746f08d6afd55e4e5..f1a3f04b77a8ef4aceac5a6dc233419c8cc26c0c 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,4 +1,11 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_gtest_executable(test_grouped_convnd_bwd_weight grouped_convnd_bwd_weight.cpp) - target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) -endif() \ No newline at end of file +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) + target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) + add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp) + target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) + set(target 1) + endif() +endforeach() \ No newline at end of file diff --git a/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp deleted file mode 100644 index 75f934cc06e23a3d9824e0df00c936333608a280..0000000000000000000000000000000000000000 --- a/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp +++ /dev/null @@ -1,91 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include - -#include - -#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" - -template -class TestGroupedConvndBwdWeight : public ::testing::Test -{ - protected: - using DataType = std::tuple_element_t<0, Tuple>; - std::vector conv_params; - ck::index_t split_k{2}; - - template - void Run() - { - for(auto& param : conv_params) - { - bool pass; - EXPECT_FALSE(conv_params.empty()); - pass = ck::profiler::profile_grouped_conv_bwd_weight_impl< - NDimSpatial, - ck::tuple_element_t>, - ck::tuple_element_t>, - ck::tuple_element_t>, - DataType, - DataType, - DataType>(true, // do_verification - 1, // init_method: integer value - false, // do_log - false, // time_kernel - param, - split_k); - EXPECT_TRUE(pass); - } - } -}; - -using KernelTypes = - ::testing::Types, std::tuple, std::tuple>; -TYPED_TEST_SUITE(TestGroupedConvndBwdWeight, KernelTypes); - -TYPED_TEST(TestGroupedConvndBwdWeight, Test1D) -{ - this->conv_params.clear(); - this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); - this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); - this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); - this->template Run<1>(); -} - -TYPED_TEST(TestGroupedConvndBwdWeight, Test2D) -{ - this->conv_params.clear(); - this->conv_params.push_back( - {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); - this->conv_params.push_back( - {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->conv_params.push_back( - {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - this->template Run<2>(); -} - -TYPED_TEST(TestGroupedConvndBwdWeight, Test3D) -{ - this->conv_params.clear(); - this->conv_params.push_back( - {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); - this->conv_params.push_back( - {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - this->conv_params.push_back( - {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); - this->template Run<3>(); -} diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ae15f6c0c93ba8b939d942204951fc29bd785995 --- /dev/null +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" + +template +class TestGroupedConvndBwdWeight : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + using WeiLayout = std::tuple_element_t<4, Tuple>; + using OutLayout = std::tuple_element_t<5, Tuple>; + using NDimSpatial = std::tuple_element_t<6, Tuple>; + + std::vector conv_params; + ck::index_t split_k{2}; + + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + split_k); + } + EXPECT_TRUE(pass); + } +}; + +template +class TestGroupedConvndBwdWeight1d : public TestGroupedConvndBwdWeight +{ +}; + +template +class TestGroupedConvndBwdWeight2d : public TestGroupedConvndBwdWeight +{ +}; + +template +class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight +{ +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes1d = ::testing::Types< + std::tuple>, + std::tuple>, + std::tuple>>; +using KernelTypes2d = ::testing::Types< + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>>; +using KernelTypes3d = ::testing::Types< + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>>; + +TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d); +TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D) +{ + this->conv_params.clear(); + this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 1, 1, 1, 1, {3}, {32}, {1}, {1}, {1}, {1}}); + this->Run(); +} + +TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->Run(); +} + +TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->Run(); +} diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cfbf13f00ed92cbc79d6d7a5a5ee1bd73c4ff193 --- /dev/null +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#include + +using F16 = ck::half_t; +using F32 = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +using S = ck::Sequence; +using ConvolutionBackwardWeightSpecialization = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + +static constexpr auto ConvBwdWeightDefault = ConvolutionBackwardWeightSpecialization::Default; +static constexpr auto Filter1x1Stride1Pad0 = + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +class TestGroupedConvndBwdWeight : public ::testing::Test +{ + protected: + static constexpr ck::index_t NDimSpatial = 2; + + using InLayout = std::tuple_element_t<2, Tuple>; + using WeiLayout = std::tuple_element_t<1, Tuple>; + using OutLayout = std::tuple_element_t<0, Tuple>; + + // clang-format off + using GroupedConvBwdWeightDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle + //##########| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //##########| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //##########| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + < NDimSpatial, InLayout, WeiLayout,OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>; + // clang-format on + + ck::utils::conv::ConvParam conv_param; + ck::index_t split_k{2}; + + template + bool Run() + { + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + std::array input_lengths{}; + std::array filter_lengths{}; + std::array output_lengths{}; + std::array input_strides{}; + std::array weights_strides{}; + std::array output_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; + + range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths)); + range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); + range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths)); + range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides)); + range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths)); + range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); + range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); + range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); + range_copy(conv_param.input_left_pads_, begin(input_left_pads)); + range_copy(conv_param.input_right_pads_, begin(input_right_pads)); + + auto conv = GroupedConvBwdWeightDeviceInstance{}; + + auto argument = conv.MakeArgument(nullptr, + nullptr, + nullptr, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + split_k); + return conv.IsSupportedArgument(argument); + } +}; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using NHWGC = ck::tensor_layout::convolution::NHWGC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using NHWGK = ck::tensor_layout::convolution::NHWGK; + +using KernelTypes = + ::testing::Types, std::tuple>; + +template +class TestGroupedConvndBwdWeightDefault + : public TestGroupedConvndBwdWeight +{ +}; + +template +class TestGroupedConvndBwdWeightFilter1x1 + : public TestGroupedConvndBwdWeight +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault, KernelTypes); +TYPED_TEST_SUITE(TestGroupedConvndBwdWeightFilter1x1, KernelTypes); + +TYPED_TEST(TestGroupedConvndBwdWeightFilter1x1, SpecializationCheck) +{ + // Check filter 3,3 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check strides 2,2 instead of 1,1 + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Check with pad + this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + + // Supported version + this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_TRUE(is_supported); +} + +TYPED_TEST(TestGroupedConvndBwdWeightDefault, VectorLoadCheck) +{ + // vector load for A + this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); + // vector load for B, E, Ds + this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} diff --git a/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp index 6df7f9969cb5d802ae28d0a7fdf03221c6c91c4f..c856255ea38afc204cddace383bec714575bead2 100644 --- a/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -22,6 +22,8 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv1dFwdGNWC) conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); conv_params.push_back({1, 2, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}}); + conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}}); for(auto& param : conv_params) { @@ -96,6 +98,9 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdGNHWC) conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); conv_params.push_back({2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); for(auto& param : conv_params) { @@ -173,6 +178,12 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv3dFwdGNDHWC) {3, 2, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); conv_params.push_back( {3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); for(auto& param : conv_params) { @@ -247,6 +258,9 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdNHWGC) conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); conv_params.push_back({2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); for(auto& param : conv_params) { @@ -255,7 +269,7 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdNHWGC) // fp16 pass = ck::profiler::profile_grouped_conv_fwd_impl<2, ck::tensor_layout::convolution::NHWGC, - ck::tensor_layout::convolution::KYXGC, + ck::tensor_layout::convolution::GKYXC, ck::tensor_layout::convolution::NHWGK, ck::half_t, ck::half_t, diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 40f634d8b3819a40898041242462be79b875b9b1..476d953ed8c302ca0a7b50488a1fb34b1ac147fc 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -1,4 +1,8 @@ -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(test_grouped_gemm) add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp) @@ -6,4 +10,7 @@ if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS M target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface) + set(target 1) + endif() +endforeach() endif() diff --git a/test/grouped_gemm/test_grouped_gemm_interface.cpp b/test/grouped_gemm/test_grouped_gemm_interface.cpp index ffa8840fc7d8ba6b8a1c1a39c62edc6edae58064..6ff3a787e734057d84036b894df89e3884bc7caf 100644 --- a/test/grouped_gemm/test_grouped_gemm_interface.cpp +++ b/test/grouped_gemm/test_grouped_gemm_interface.cpp @@ -108,7 +108,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops) // kloops % 2 Ks = std::vector{256, 512, 320, 768}; - EXPECT_FALSE( + EXPECT_TRUE( DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch)); // Not all gemms have same value for main_k0_block_loop! diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index b61118b5120e577defe9453c741875cafe33d168..04b31dcc912a64b46f4d04653522cd8a3992902f 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -147,14 +147,14 @@ struct DeviceGroupedGemmSplitkInstanceWrapper 32, 4, 2, - S<1, 4, 32, 1>, + S<1, 4, 16, 1>, ABlockTransferThreadClusterArrageOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim::value, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1::value, ABlockLdsAddExtraM::value, - S<1, 4, 32, 1>, + S<1, 4, 16, 1>, BBlockTransferThreadClusterArrageOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim::value, diff --git a/test/image_to_column/CMakeLists.txt b/test/image_to_column/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0feb827b552b8331876087ca4b61a6a2929bffff --- /dev/null +++ b/test/image_to_column/CMakeLists.txt @@ -0,0 +1,4 @@ +add_gtest_executable(test_image_to_column test_image_to_column.cpp) +target_link_libraries(test_image_to_column PRIVATE utility device_image_to_column_instance) +add_gtest_executable(test_image_to_column_interface test_image_to_column_interface.cpp) +target_link_libraries(test_image_to_column_interface PRIVATE utility) diff --git a/test/image_to_column/test_image_to_column.cpp b/test/image_to_column/test_image_to_column.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b17cac2d07d8859a3c92cf2fc45780105e76b29 --- /dev/null +++ b/test/image_to_column/test_image_to_column.cpp @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include + +#include "profiler/profile_image_to_column_impl.hpp" + +template +class TestImageToColumn : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using OutDataType = std::tuple_element_t<1, Tuple>; + using InLayout = std::tuple_element_t<2, Tuple>; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_image_to_column_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes1d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes2d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +template +class TestImageToColumn1d : public TestImageToColumn +{ +}; + +template +class TestImageToColumn2d : public TestImageToColumn +{ +}; + +template +class TestImageToColumn3d : public TestImageToColumn +{ +}; + +TYPED_TEST_SUITE(TestImageToColumn1d, KernelTypes1d); +TYPED_TEST_SUITE(TestImageToColumn2d, KernelTypes2d); +TYPED_TEST_SUITE(TestImageToColumn3d, KernelTypes3d); + +TYPED_TEST(TestImageToColumn1d, Test1D) +{ + this->conv_params.clear(); + + this->conv_params.push_back({1, 1, 4, 1, 192, {3}, {28}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 1, 64, 1, 64, {3}, {14}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 1, 64, 1, 64, {1}, {7}, {2}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 1, 64, 1, 64, {1}, {3}, {1}, {1}, {0}, {0}}); + // ScalarPerVector should be 1 + this->conv_params.push_back({1, 1, 4, 1, 1, {3}, {28}, {1}, {1}, {1}, {1}}); + // stride != 1 + this->conv_params.push_back({1, 1, 1, 1, 4, {3}, {28}, {2}, {1}, {1}, {1}}); + // dilation != 1 + this->conv_params.push_back({1, 1, 1, 1, 4, {3}, {28}, {1}, {2}, {1}, {1}}); + this->template Run<1>(); +} + +TYPED_TEST(TestImageToColumn2d, Test2D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {2, 1, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 1, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->template Run<2>(); +} + +TYPED_TEST(TestImageToColumn3d, Test3D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 1, 16, 1, 64, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 2, 1, 64, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 32, 1, 64, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->template Run<3>(); +} diff --git a/test/image_to_column/test_image_to_column_interface.cpp b/test/image_to_column/test_image_to_column_interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ea8b9632e1cc5dc346f6a4020b860213ec8d2dab --- /dev/null +++ b/test/image_to_column/test_image_to_column_interface.cpp @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#include + +using DataType = float; +using InLayout = ck::tensor_layout::convolution::GNWC; + +template +using S = ck::Sequence; + +template +class TestImageToColumnInterface : public ::testing::Test +{ + protected: + static constexpr ck::index_t NDimSpatial = 1; + + // clang-format off + using DeviceImgToColInstance = ck::tensor_operation::device::DeviceImageToColumnImpl + //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + < NDimSpatial, InLayout, DataType, DataType, 256, 128, 128, S<16, 16>,ScalarPerVector>; + // clang-format on + + ck::utils::conv::ConvParam conv_param; + + bool Run() + { + + const auto N = conv_param.N_; + const auto C = conv_param.C_; + const auto FakeC = + conv_param.C_ / 2; // Fake C to simulate the behavior that C is not packed + + const ck::index_t NDoHoWo = + N * + ck::accumulate_n( + conv_param.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const ck::index_t CZYX = + C * + ck::accumulate_n( + conv_param.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto in_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + const auto out_desc = HostTensorDescriptor({NDoHoWo, CZYX}); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array input_g_n_c_wis_strides{}; + std::array output_m_k_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(conv_param.input_spatial_lengths_, input_spatial_lengths); + copy(conv_param.filter_spatial_lengths_, filter_spatial_lengths); + copy(conv_param.output_spatial_lengths_, output_spatial_lengths); + copy(in_desc.GetStrides(), input_g_n_c_wis_strides); + copy(out_desc.GetStrides(), output_m_k_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto img2col = DeviceImgToColInstance{}; + auto argument = img2col.MakeArgument(nullptr, + nullptr, + N, + IsCPacked ? C : FakeC, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + input_g_n_c_wis_strides, + output_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + return img2col.IsSupportedArgument(argument); + } +}; + +class TestImageToColumnInterface1ScalarPerVector : public TestImageToColumnInterface<1, true> +{ +}; + +class TestImageToColumnInterface4ScalarPerVector : public TestImageToColumnInterface<4, true> +{ +}; + +class TestImageToColumnInterface4ScalarPerVectorFakeC : public TestImageToColumnInterface<4, false> +{ +}; + +TEST_F(TestImageToColumnInterface1ScalarPerVector, X1ScalarPerVector) +{ + // vector load C * X % ScalarPerVector + this->conv_param = {1, 1, 1, 1, 1, {3}, {3}, {1}, {1}, {0}, {0}}; + bool is_supported = this->Run(); + EXPECT_TRUE(is_supported); + // vector load C * left_pad_x % ScalarPerVector + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {1}, {1}, {3}, {0}}; + is_supported = this->Run(); + EXPECT_TRUE(is_supported); + // vector load C * right_pad_x % ScalarPerVector + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {1}, {1}, {0}, {3}}; + is_supported = this->Run(); + EXPECT_TRUE(is_supported); + + // vector load C % ScalarPerVector, right_pad and stride + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {2}, {1}, {0}, {3}}; + is_supported = this->Run(); + EXPECT_TRUE(is_supported); + // vector load C % ScalarPerVector, left_pad and stride + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {2}, {1}, {3}, {0}}; + is_supported = this->Run(); + EXPECT_TRUE(is_supported); + // vector load C % ScalarPerVector, dilation + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {1}, {2}, {0}, {0}}; + is_supported = this->Run(); + EXPECT_TRUE(is_supported); + + // C = 4 + this->conv_param = {1, 1, 1, 1, 4, {3}, {3}, {1}, {1}, {3}, {3}}; + is_supported = this->Run(); + EXPECT_TRUE(is_supported); +} + +TEST_F(TestImageToColumnInterface4ScalarPerVector, X4ScalarPerVector) +{ + // vector load C * X % ScalarPerVector + this->conv_param = {1, 1, 1, 1, 1, {3}, {3}, {1}, {1}, {0}, {0}}; + bool is_supported = this->Run(); + EXPECT_FALSE(is_supported); + // vector load C * left_pad_x % ScalarPerVector + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {1}, {1}, {3}, {0}}; + is_supported = this->Run(); + EXPECT_FALSE(is_supported); + // vector load C * right_pad_x % ScalarPerVector + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {1}, {1}, {0}, {3}}; + is_supported = this->Run(); + EXPECT_FALSE(is_supported); + + // vector load C % ScalarPerVector, right_pad and stride + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {2}, {1}, {0}, {3}}; + is_supported = this->Run(); + EXPECT_FALSE(is_supported); + // vector load C % ScalarPerVector, left_pad and stride + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {2}, {1}, {3}, {0}}; + is_supported = this->Run(); + EXPECT_FALSE(is_supported); + // vector load C % ScalarPerVector, dilation + this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {1}, {2}, {0}, {0}}; + is_supported = this->Run(); + EXPECT_FALSE(is_supported); + + // C = 4 + this->conv_param = {1, 1, 1, 1, 4, {3}, {3}, {1}, {1}, {3}, {3}}; + is_supported = this->Run(); + EXPECT_TRUE(is_supported); +} + +TEST_F(TestImageToColumnInterface4ScalarPerVectorFakeC, X4ScalarPerVectorFakeC) +{ + // C = 3 + this->conv_param = {1, 1, 1, 1, 3, {4}, {3}, {1}, {1}, {0}, {0}}; + bool is_supported = this->Run(); + EXPECT_FALSE(is_supported); + // C = 4 + this->conv_param = {1, 1, 1, 1, 8, {4}, {3}, {1}, {1}, {0}, {0}}; + is_supported = this->Run(); + EXPECT_TRUE(is_supported); +} diff --git a/test/magic_number_division/magic_number_division.cpp b/test/magic_number_division/magic_number_division.cpp index 680fddf1933611dc3088cb673ece0dba9e1f3911..253f21e91f80c96c081108b47b045c3902b36926 100644 --- a/test/magic_number_division/magic_number_division.cpp +++ b/test/magic_number_division/magic_number_division.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/normalization/CMakeLists.txt b/test/normalization/CMakeLists.txt index a5d7fb2982c9dac7cc2da58a8f9f6be413cfbd18..2beda4dd74483e7462c28604ced976ba1474eff4 100644 --- a/test/normalization/CMakeLists.txt +++ b/test/normalization/CMakeLists.txt @@ -1,16 +1,19 @@ -add_custom_target(test_normalization) - -add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp) -add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp) -add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp) -add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp) - -target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance) -target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance) -target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance) -target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance) - -add_dependencies(test_normalization test_layernorm2d_fp32) -add_dependencies(test_normalization test_layernorm2d_fp16) -add_dependencies(test_normalization test_groupnorm_fp16) -add_dependencies(test_normalization test_groupnorm_fp32) +if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_custom_target(test_normalization) +endif() +if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) + add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp) + add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp) + target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance) + target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance) + add_dependencies(test_normalization test_layernorm2d_fp32) + add_dependencies(test_normalization test_groupnorm_fp32) +endif() +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp) + add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp) + target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance) + target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance) + add_dependencies(test_normalization test_layernorm2d_fp16) + add_dependencies(test_normalization test_groupnorm_fp16) +endif() diff --git a/test/normalization/test_groupnorm_fp16.cpp b/test/normalization/test_groupnorm_fp16.cpp index 60d3b13959fc835e4d42244bfff5437561f3178a..325ea75fe5668897a5b489216db042180b14a8ef 100644 --- a/test/normalization/test_groupnorm_fp16.cpp +++ b/test/normalization/test_groupnorm_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_groupnorm_impl.hpp" diff --git a/test/normalization/test_groupnorm_fp32.cpp b/test/normalization/test_groupnorm_fp32.cpp index 3542f73a62f05f690f12270b44c92816816743c3..ec88442fc02623767f299de36f8b6c5810569cc6 100644 --- a/test/normalization/test_groupnorm_fp32.cpp +++ b/test/normalization/test_groupnorm_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_groupnorm_impl.hpp" diff --git a/test/normalization/test_layernorm2d_fp16.cpp b/test/normalization/test_layernorm2d_fp16.cpp index d627cbe7f1187e904ce8bfa452274af509708101..2222740fcceaa4b9dfe22a299dc0cd3e054e5d35 100644 --- a/test/normalization/test_layernorm2d_fp16.cpp +++ b/test/normalization/test_layernorm2d_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_layernorm_impl.hpp" diff --git a/test/normalization/test_layernorm2d_fp32.cpp b/test/normalization/test_layernorm2d_fp32.cpp index de4133aa8369b08be2fe3c3741081d121ceb59f6..30fbe06c60d8ebfcbdd87b3e27c11c601d81635b 100644 --- a/test/normalization/test_layernorm2d_fp32.cpp +++ b/test/normalization/test_layernorm2d_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_layernorm_impl.hpp" diff --git a/test/pool/CMakeLists.txt b/test/pool/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fac806897a1f97e4decea74c445793599f884868 --- /dev/null +++ b/test/pool/CMakeLists.txt @@ -0,0 +1,16 @@ +add_custom_target(test_pool) + +add_gtest_executable(test_avg_pool3d_bwd test_avg_pool3d_bwd.cpp) +add_gtest_executable(test_max_pool3d_bwd test_max_pool3d_bwd.cpp) +add_gtest_executable(test_avg_pool3d_fwd test_avg_pool3d_fwd.cpp) +add_gtest_executable(test_max_pool3d_fwd test_max_pool3d_fwd.cpp) + +target_link_libraries(test_avg_pool3d_bwd PRIVATE utility device_avg_pool3d_bwd_instance) +target_link_libraries(test_max_pool3d_bwd PRIVATE utility device_max_pool_bwd_instance) +target_link_libraries(test_avg_pool3d_fwd PRIVATE utility device_pool3d_fwd_instance) +target_link_libraries(test_max_pool3d_fwd PRIVATE utility device_pool3d_fwd_instance) + +add_dependencies(test_pool test_avg_pool3d_bwd) +add_dependencies(test_pool test_max_pool3d_bwd) +add_dependencies(test_pool test_avg_pool3d_fwd) +add_dependencies(test_pool test_max_pool3d_fwd) diff --git a/test/pool/test_avg_pool3d_bwd.cpp b/test/pool/test_avg_pool3d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fbd03fdf45160873d98bc99b5484bb342f2f58a2 --- /dev/null +++ b/test/pool/test_avg_pool3d_bwd.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_avg_pool3d_bwd_impl.hpp" +#include "test_pool_fwd_common.hpp" + +template +class TestAvgPool3dBwd : public ::testing::Test +{ + protected: + using DOutDataType = std::tuple_element_t<0, Tuple>; + using DInDataType = std::tuple_element_t<1, Tuple>; + using ComputeDataType = std::tuple_element_t<2, Tuple>; + using DOutLayout = std::tuple_element_t<3, Tuple>; + using DInLayout = std::tuple_element_t<4, Tuple>; + + std::vector params; + + void Run() + { + for(auto param : params) + { + bool success = + ck::profiler::profile_avg_pool3d_bwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.window_dilations_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + } + } +}; + +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP32) +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple>; +#elif defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP32) +using KernelTypes = ::testing::Types, + std::tuple>; +#elif defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP32) +using KernelTypes = ::testing::Types, + std::tuple>; +#elif defined(CK_ENABLE_FP16) && defined(CK_ENABLE_BF16) +using KernelTypes = ::testing::Types, + std::tuple>; +#elif defined(CK_ENABLE_FP16) +using KernelTypes = ::testing::Types>; +#elif defined(CK_ENABLE_BF16) +using KernelTypes = ::testing::Types>; +#elif defined(CK_ENABLE_FP32) +using KernelTypes = ::testing::Types>; +#endif + +TYPED_TEST_SUITE(TestAvgPool3dBwd, KernelTypes); +TYPED_TEST(TestAvgPool3dBwd, Test_Pool) +{ + // length, window_length, window_stride, window_dilation, left_pad, right_pad + this->params = {{{1, 1, 1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 16, 64, 64, 64}, {4, 4, 4}, {4, 4, 4}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}}, + {{2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}}; + + this->Run(); +} diff --git a/test/pool_fwd/test_avg_pool3d_fwd.cpp b/test/pool/test_avg_pool3d_fwd.cpp similarity index 72% rename from test/pool_fwd/test_avg_pool3d_fwd.cpp rename to test/pool/test_avg_pool3d_fwd.cpp index 0d6b105b1b9aaaadb4ff98d4ad325f3850aa747e..fc196a8a07f2a1c529aedc023a61b64e3c67e7d4 100644 --- a/test/pool_fwd/test_avg_pool3d_fwd.cpp +++ b/test/pool/test_avg_pool3d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_pool3d_fwd_impl.hpp" @@ -25,6 +25,8 @@ class TestAvgPool3dFwd : public ::testing::Test OutDataType, ComputeDataType, IndexDataType, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::NDHWC, ck::ReduceTensorOp::AVG, false, false>(true, @@ -34,23 +36,27 @@ class TestAvgPool3dFwd : public ::testing::Test param.length_, param.window_spatial_lengths_, param.window_strides_, + param.window_dilations_, param.input_left_pads_, param.input_right_pads_); EXPECT_TRUE(success); } } }; - +#ifdef CK_ENABLE_FP16 using KernelTypes = ::testing::Types, std::tuple>; - +#else +using KernelTypes = ::testing::Types>; +#endif TYPED_TEST_SUITE(TestAvgPool3dFwd, KernelTypes); TYPED_TEST(TestAvgPool3dFwd, Test_Pool) { - // length, window_length, window_stride, left_pad, right_pad - this->params = {{{1, 1, 1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, - {{2, 16, 64, 64, 64}, {64, 64, 64}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, - {{2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}}}; + // length, window_length, window_stride, window_dilation, left_pad, right_pad + this->params = {{{1, 1, 1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 16, 64, 64, 64}, {64, 64, 64}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 16, 64, 64, 64}, {4, 4, 4}, {4, 4, 4}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}}, + {{2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}}; this->Run(); } diff --git a/test/pool/test_max_pool3d_bwd.cpp b/test/pool/test_max_pool3d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8d52bde4da932d4c715938ddaaf162d118367bd0 --- /dev/null +++ b/test/pool/test_max_pool3d_bwd.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_max_pool3d_bwd_impl.hpp" +#include "test_pool_fwd_common.hpp" + +template +class TestMaxPool3dBwd : public ::testing::Test +{ + protected: + using DOutDataType = std::tuple_element_t<0, Tuple>; + using DInDataType = std::tuple_element_t<1, Tuple>; + using IndexDataType = std::tuple_element_t<2, Tuple>; + + using InDataType = DInDataType; + using OutDataType = DOutDataType; + + std::vector params; + + void Run() + { + for(auto param : params) + { + bool success = + ck::profiler::profile_max_pool3d_bwd_impl(true, + 2, + false, + false, + param.length_, + param.window_spatial_lengths_, + param.window_strides_, + param.window_dilations_, + param.input_left_pads_, + param.input_right_pads_); + EXPECT_TRUE(success); + } + } +}; + +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP32) +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple>; +#elif defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP32) +using KernelTypes = ::testing::Types, + std::tuple>; +#elif defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP32) +using KernelTypes = ::testing::Types, + std::tuple>; +#elif defined(CK_ENABLE_FP16) && defined(CK_ENABLE_BF16) +using KernelTypes = ::testing::Types, + std::tuple>; +#elif defined(CK_ENABLE_FP16) +using KernelTypes = ::testing::Types>; +#elif defined(CK_ENABLE_BF16) +using KernelTypes = ::testing::Types>; +#elif defined(CK_ENABLE_FP32) +using KernelTypes = ::testing::Types>; +#endif + +TYPED_TEST_SUITE(TestMaxPool3dBwd, KernelTypes); +TYPED_TEST(TestMaxPool3dBwd, Test_Pool) +{ + // length, window_length, window_stride, window_dilation, left_pad, right_pad + this->params = {{{1, 1, 1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 16, 64, 64, 64}, {4, 4, 4}, {4, 4, 4}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}}, + {{2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}}; + + // this->params = {{{2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, + // 1}}}; + + this->Run(); +} diff --git a/test/pool_fwd/test_max_pool3d_fwd.cpp b/test/pool/test_max_pool3d_fwd.cpp similarity index 74% rename from test/pool_fwd/test_max_pool3d_fwd.cpp rename to test/pool/test_max_pool3d_fwd.cpp index f084dd9cbf9f0cffce33947c63c2f721442264ef..7189f1b10421118503e1dc00834af2fd488e2cd6 100644 --- a/test/pool_fwd/test_max_pool3d_fwd.cpp +++ b/test/pool/test_max_pool3d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "profiler/profile_pool3d_fwd_impl.hpp" @@ -26,6 +26,8 @@ class TestMaxPool3dFwd : public ::testing::Test OutDataType, ComputeDataType, IndexDataType, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::NDHWC, ck::ReduceTensorOp::MAX, false, false>(true, @@ -35,6 +37,7 @@ class TestMaxPool3dFwd : public ::testing::Test param.length_, param.window_spatial_lengths_, param.window_strides_, + param.window_dilations_, param.input_left_pads_, param.input_right_pads_); EXPECT_TRUE(success); @@ -44,6 +47,8 @@ class TestMaxPool3dFwd : public ::testing::Test OutDataType, ComputeDataType, IndexDataType, + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::NDHWC, ck::ReduceTensorOp::MAX, false, true>(true, @@ -53,6 +58,7 @@ class TestMaxPool3dFwd : public ::testing::Test param.length_, param.window_spatial_lengths_, param.window_strides_, + param.window_dilations_, param.input_left_pads_, param.input_right_pads_); EXPECT_TRUE(success); @@ -60,16 +66,21 @@ class TestMaxPool3dFwd : public ::testing::Test } }; +#ifdef CK_ENABLE_FP16 using KernelTypes = - ::testing::Types, std::tuple>; + ::testing::Types, std::tuple>; +#else +using KernelTypes = ::testing::Types>; +#endif TYPED_TEST_SUITE(TestMaxPool3dFwd, KernelTypes); TYPED_TEST(TestMaxPool3dFwd, Test_Pool) { - // length, window_length, window_stride, left_pad, right_pad - this->params = {{{1, 1, 1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, - {{2, 16, 64, 64, 64}, {64, 64, 64}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, - {{2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}}}; + // length, window_length, window_stride, window_dilation, left_pad, right_pad + this->params = {{{1, 1, 1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 16, 64, 64, 64}, {64, 64, 64}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}, + {{2, 16, 64, 64, 64}, {4, 4, 4}, {4, 4, 4}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}}, + {{2, 32, 30, 30, 30}, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}}; this->Run(); } diff --git a/test/pool_fwd/test_pool_fwd_common.hpp b/test/pool/test_pool_fwd_common.hpp similarity index 71% rename from test/pool_fwd/test_pool_fwd_common.hpp rename to test/pool/test_pool_fwd_common.hpp index a78785171d9dc53da9f85782bb312176e1012bd6..5917a27e56cb955ed97f074f30805053a9ebe9dd 100644 --- a/test/pool_fwd/test_pool_fwd_common.hpp +++ b/test/pool/test_pool_fwd_common.hpp @@ -1,24 +1,28 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" #include "ck/ck.hpp" -using F16 = ck::half_t; -using F32 = float; -using I32 = int32_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using I32 = int32_t; using ck::index_t; +using NDHWC = ck::tensor_layout::convolution::NDHWC; struct PoolingParam { PoolingParam(const std::vector& length, const std::vector& window_spatial_lengths, const std::vector& window_strides, + const std::vector& window_dilations, const std::vector& input_left_pads, const std::vector& input_right_pads) : length_(length), window_spatial_lengths_(window_spatial_lengths), window_strides_(window_strides), + window_dilations_(window_dilations), input_left_pads_(input_left_pads), input_right_pads_(input_right_pads) { @@ -26,6 +30,7 @@ struct PoolingParam std::vector length_; std::vector window_spatial_lengths_; std::vector window_strides_; + std::vector window_dilations_; std::vector input_left_pads_; std::vector input_right_pads_; }; diff --git a/test/pool_fwd/CMakeLists.txt b/test/pool_fwd/CMakeLists.txt deleted file mode 100644 index 6f59b95f6fce8679a022f2bb57de80983cf261a5..0000000000000000000000000000000000000000 --- a/test/pool_fwd/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_custom_target(test_pool_fwd) - -add_gtest_executable(test_avg_pool2d_fwd test_avg_pool2d_fwd.cpp) -add_gtest_executable(test_avg_pool3d_fwd test_avg_pool3d_fwd.cpp) -add_gtest_executable(test_max_pool2d_fwd test_max_pool2d_fwd.cpp) -add_gtest_executable(test_max_pool3d_fwd test_max_pool3d_fwd.cpp) - -target_link_libraries(test_avg_pool2d_fwd PRIVATE utility device_pool_fwd_instance) -target_link_libraries(test_avg_pool3d_fwd PRIVATE utility device_pool_fwd_instance) -target_link_libraries(test_max_pool2d_fwd PRIVATE utility device_pool_fwd_instance) -target_link_libraries(test_max_pool3d_fwd PRIVATE utility device_pool_fwd_instance) - -add_dependencies(test_pool_fwd test_avg_pool2d_fwd) -add_dependencies(test_pool_fwd test_avg_pool3d_fwd) -add_dependencies(test_pool_fwd test_max_pool2d_fwd) -add_dependencies(test_pool_fwd test_max_pool3d_fwd) diff --git a/test/pool_fwd/test_avg_pool2d_fwd.cpp b/test/pool_fwd/test_avg_pool2d_fwd.cpp deleted file mode 100644 index 4e5f1e0e919dc41403cacd46e7619e8e6093f5fc..0000000000000000000000000000000000000000 --- a/test/pool_fwd/test_avg_pool2d_fwd.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "gtest/gtest.h" -#include "profiler/profile_pool2d_fwd_impl.hpp" -#include "test_pool_fwd_common.hpp" - -template -class TestAvgPool2dFwd : public ::testing::Test -{ - protected: - using InDataType = std::tuple_element_t<0, Tuple>; - using OutDataType = std::tuple_element_t<1, Tuple>; - using ComputeDataType = std::tuple_element_t<2, Tuple>; - using IndexDataType = std::tuple_element_t<3, Tuple>; - - std::vector params; - - void Run() - { - for(auto param : params) - { - bool success = - ck::profiler::profile_pool2d_fwd_impl(true, - 2, - false, - false, - param.length_, - param.window_spatial_lengths_, - param.window_strides_, - param.input_left_pads_, - param.input_right_pads_); - EXPECT_TRUE(success); - } - } -}; - -using KernelTypes = - ::testing::Types, std::tuple>; - -TYPED_TEST_SUITE(TestAvgPool2dFwd, KernelTypes); -TYPED_TEST(TestAvgPool2dFwd, Test_Pool) -{ - // length, window_length, window_stride, left_pad, right_pad - this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, - {{2, 16, 64, 64}, {64, 64}, {1, 1}, {0, 0}, {0, 0}}, - {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}}; - - this->Run(); -} diff --git a/test/pool_fwd/test_max_pool2d_fwd.cpp b/test/pool_fwd/test_max_pool2d_fwd.cpp deleted file mode 100644 index d16ac7fabac44156e7cca553a62228f397f9093f..0000000000000000000000000000000000000000 --- a/test/pool_fwd/test_max_pool2d_fwd.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "gtest/gtest.h" -#include "profiler/profile_pool2d_fwd_impl.hpp" -#include "test_pool_fwd_common.hpp" - -template -class TestMaxPool2dFwd : public ::testing::Test -{ - protected: - using InDataType = std::tuple_element_t<0, Tuple>; - using OutDataType = std::tuple_element_t<1, Tuple>; - using ComputeDataType = std::tuple_element_t<2, Tuple>; - using IndexDataType = std::tuple_element_t<3, Tuple>; - - std::vector params; - - void Run() - { - for(auto param : params) - { - // max pool - bool success = - ck::profiler::profile_pool2d_fwd_impl(true, - 2, - false, - false, - param.length_, - param.window_spatial_lengths_, - param.window_strides_, - param.input_left_pads_, - param.input_right_pads_); - EXPECT_TRUE(success); - - // max pool + index - success = ck::profiler::profile_pool2d_fwd_impl(true, - 2, - false, - false, - param.length_, - param.window_spatial_lengths_, - param.window_strides_, - param.input_left_pads_, - param.input_right_pads_); - EXPECT_TRUE(success); - } - } -}; - -using KernelTypes = - ::testing::Types, std::tuple>; - -TYPED_TEST_SUITE(TestMaxPool2dFwd, KernelTypes); -TYPED_TEST(TestMaxPool2dFwd, Test_Pool) -{ - // length, window_length, window_stride, left_pad, right_pad - this->params = {{{1, 1, 1, 1}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}, - {{2, 16, 64, 64}, {64, 64}, {1, 1}, {0, 0}, {0, 0}}, - {{2, 32, 30, 30}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}}; - - this->Run(); -} diff --git a/test/reduce/reduce_no_index.cpp b/test/reduce/reduce_no_index.cpp index 3f4d0676b4da84751dd74cd4e76266059c056d67..1ab452442d0b219820e208420a4c4c04a019e9a1 100644 --- a/test/reduce/reduce_no_index.cpp +++ b/test/reduce/reduce_no_index.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/reduce/reduce_with_index.cpp b/test/reduce/reduce_with_index.cpp index c616a68e741e14119d70e0d1d4ba508e371dea78..0301669c5a7578b7c6a0c6ba160685b13295b492 100644 --- a/test/reduce/reduce_with_index.cpp +++ b/test/reduce/reduce_with_index.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp index 1f9ba0064cb3b22c394c75d5119f71a188d11a5a..b3328e4b365268af641a6ecb1a65ded872f94ddd 100644 --- a/test/reference_conv_fwd/reference_conv_fwd.cpp +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/softmax/test_softmax_interface.cpp b/test/softmax/test_softmax_interface.cpp index 8cac0ba0f52b035dae2fb52f0b0346d3267e5b7b..25f666f0ea99eb12e0fa248fdd573aebed0de7a6 100644 --- a/test/softmax/test_softmax_interface.cpp +++ b/test/softmax/test_softmax_interface.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/softmax/test_softmax_rank3.cpp b/test/softmax/test_softmax_rank3.cpp index 5691ee3f6cd5be1bd8ff1075cb57fcfd86aa8322..a8b950ce6b4775bd25f98b5cdd82d413fe121332 100644 --- a/test/softmax/test_softmax_rank3.cpp +++ b/test/softmax/test_softmax_rank3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -10,10 +10,10 @@ template using I = ck::Number; - +#ifdef CK_ENABLE_FP16 using F16 = ck::half_t; +#endif using F32 = float; -using I8 = int8_t; template class TestSoftmax : public ck::TestSoftmax @@ -23,9 +23,10 @@ class TestSoftmax : public ck::TestSoftmax // clang-format off using KernelTypes = ::testing::Types< // InDataType, AccDataType, OutDataType, Rank +#ifdef CK_ENABLE_FP16 std::tuple< F16, F32, F16, I<3>>, - std::tuple< F32, F32, F32, I<3>>, - std::tuple< I8, F32, I8, I<3>> +#endif + std::tuple< F32, F32, F32, I<3>> >; // clang-format on diff --git a/test/softmax/test_softmax_rank4.cpp b/test/softmax/test_softmax_rank4.cpp index f0b22df25ebdfaaecbf4548ec2b44fdcd477aae4..cbd790ac9b51fa1b489cd3f75f3afce40fcbb651 100644 --- a/test/softmax/test_softmax_rank4.cpp +++ b/test/softmax/test_softmax_rank4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -10,10 +10,10 @@ template using I = ck::Number; - +#ifdef CK_ENABLE_FP16 using F16 = ck::half_t; +#endif using F32 = float; -using I8 = int8_t; template class TestSoftmax : public ck::TestSoftmax @@ -23,9 +23,10 @@ class TestSoftmax : public ck::TestSoftmax // clang-format off using KernelTypes = ::testing::Types< // InDataType, AccDataType, OutDataType, Rank +#ifdef CK_ENABLE_FP16 std::tuple< F16, F32, F16, I<4>>, - std::tuple< F32, F32, F32, I<4>>, - std::tuple< I8, F32, I8, I<4>> +#endif + std::tuple< F32, F32, F32, I<4>> >; // clang-format on diff --git a/test/softmax/test_softmax_util.hpp b/test/softmax/test_softmax_util.hpp index 40b300cf9927de3556c586805923de9ab97db466..1409af8453b3a9d722fd5074d311f7066f295c65 100644 --- a/test/softmax/test_softmax_util.hpp +++ b/test/softmax/test_softmax_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -61,8 +61,92 @@ class TestSoftmax : public ::testing::Test int init_method = 1; // integer value initialization bool log = false; std::vector strides; // intenionally empty, to get packed layout. - bool pass = ck::profiler::profile_softmax_impl( - verify_, init_method, log, bench_, in_length, strides, reduce_dims, alpha, beta); + bool pass = false; + + if constexpr(Rank == 3) + { + if(reduce_dims.size() == 1) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 2) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 3) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + } + else if constexpr(Rank == 4) + { + if(reduce_dims.size() == 1) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 2) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 3) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + else if(reduce_dims.size() == 4) + pass = ck::profiler:: + profile_softmax_impl(verify_, + init_method, + log, + bench_, + in_length, + strides, + reduce_dims, + alpha, + beta); + }; + EXPECT_TRUE(pass); } diff --git a/test/space_filling_curve/space_filling_curve.cpp b/test/space_filling_curve/space_filling_curve.cpp index c7f6759e819512f55c17463ffdb503e55bb8c59e..a192ecb28f1b6c67fa7ec27790349e041c6718fb 100644 --- a/test/space_filling_curve/space_filling_curve.cpp +++ b/test/space_filling_curve/space_filling_curve.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/wmma_op/wmma_op.cpp b/test/wmma_op/wmma_op.cpp index 761c15f1dd8a9d22de14a2998d73a3877b1f3003..47d8c7ed6f35236192cc9c4cfacde2e754560241 100644 --- a/test/wmma_op/wmma_op.cpp +++ b/test/wmma_op/wmma_op.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp index c70e6a407de6fce20df64dfa29f9aa19de1d9e30..49782bce6e21d4a250b0d07526abf4e39ac57b9d 100644 --- a/test/wmma_op/wmma_op_util.hpp +++ b/test/wmma_op/wmma_op_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once